aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Marius Venstad <jonmv@users.noreply.github.com>2018-01-02 12:12:08 +0100
committerGitHub <noreply@github.com>2018-01-02 12:12:08 +0100
commit95ae8c562a8826f03bd2faad82b0ffb754133342 (patch)
tree35545093828437f6d7704b4b6c3646a39ff50a00
parent21dd7e03056a77e1de75e6e95413c3b00e6615ec (diff)
parentb73a4d2c8b5e7ae83743b10b8f21836811e5dff4 (diff)
Merge branch 'master' into jvenstad/zone-cleanup-4
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java23
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java20
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java2
-rwxr-xr-xconfig-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--configd/src/apps/sentinel/service.cpp4
-rw-r--r--configdefinitions/src/vespa/configserver.def1
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java1
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java5
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java10
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java7
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java7
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java9
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java10
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java9
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java9
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java7
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java10
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java10
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java15
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java2
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java3
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java9
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java12
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java8
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java9
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java16
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java11
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java11
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java4
-rw-r--r--container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java3
-rw-r--r--container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java5
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java35
-rw-r--r--container-dev/pom.xml24
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java2
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java3
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java60
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java1
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java12
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java8
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java2
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java8
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java2
-rw-r--r--container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java2
-rw-r--r--container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java29
-rw-r--r--container/pom.xml12
-rw-r--r--controller-api/pom.xml80
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java41
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java9
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java40
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java2
-rw-r--r--controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java82
-rw-r--r--controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java21
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java13
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java6
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java43
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java52
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java33
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java6
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java110
-rw-r--r--document/src/main/java/com/yahoo/document/DataType.java8
-rw-r--r--document/src/main/java/com/yahoo/document/DocumentTypeManager.java12
-rw-r--r--document/src/main/java/com/yahoo/document/TensorDataType.java6
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java4
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java8
-rw-r--r--document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java2
-rw-r--r--document/src/vespa/document/bucket/bucketspace.h18
-rw-r--r--document/src/vespa/document/select/CMakeLists.txt12
-rw-r--r--document/src/vespa/document/select/context.cpp12
-rw-r--r--document/src/vespa/document/select/grammar/lexer.ll4
-rw-r--r--document/src/vespa/document/select/grammar/parser.yy3
-rw-r--r--document/src/vespa/document/test/make_bucket_space.cpp6
-rwxr-xr-xdocumentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java14
-rwxr-xr-xdocumentapi/src/main/java/com/yahoo/documentapi/SyncSession.java9
-rwxr-xr-xdocumentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java16
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp12
-rw-r--r--eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp2
-rw-r--r--eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp26
-rw-r--r--eval/src/vespa/eval/eval/operation.h8
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp4
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp4
-rw-r--r--eval/src/vespa/eval/eval/value_type.h13
-rw-r--r--eval/src/vespa/eval/tensor/cell_function.h6
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor.cpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor.h16
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp61
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h54
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h16
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp2
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp27
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h37
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h10
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h9
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h7
-rw-r--r--eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp9
-rw-r--r--eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h9
-rw-r--r--eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp9
-rw-r--r--eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h9
-rw-r--r--eval/src/vespa/eval/tensor/direct_tensor_builder.h6
-rw-r--r--eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h11
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp12
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h36
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h37
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp33
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h29
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h9
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h12
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp19
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h16
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h15
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h15
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp16
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp11
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h21
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp21
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h6
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp9
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h11
-rw-r--r--eval/src/vespa/eval/tensor/tensor_address.h6
-rw-r--r--eval/src/vespa/eval/tensor/tensor_address_builder.h6
-rw-r--r--eval/src/vespa/eval/tensor/tensor_apply.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/tensor_apply.h6
-rw-r--r--eval/src/vespa/eval/tensor/tensor_mapper.cpp1
-rw-r--r--eval/src/vespa/eval/tensor/tensor_operation.h6
-rw-r--r--eval/src/vespa/eval/tensor/tensor_visitor.h6
-rw-r--r--eval/src/vespa/eval/tensor/types.h6
-rw-r--r--filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java5
-rw-r--r--filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java2
-rw-r--r--filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java10
-rw-r--r--filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java8
-rw-r--r--jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java16
-rw-r--r--jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java5
-rw-r--r--jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java8
-rw-r--r--metrics/src/vespa/metrics/metrictimer.h2
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java3
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java3
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java65
-rw-r--r--node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java4
-rw-r--r--node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java5
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java38
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java17
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java10
-rw-r--r--persistence/src/tests/spi/CMakeLists.txt1
-rw-r--r--persistence/src/tests/spi/fixed_bucket_spaces_test.cpp64
-rw-r--r--persistence/src/vespa/persistence/conformancetest/conformancetest.cpp1
-rw-r--r--persistence/src/vespa/persistence/spi/CMakeLists.txt13
-rw-r--r--persistence/src/vespa/persistence/spi/context.h8
-rw-r--r--persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp33
-rw-r--r--persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h30
-rw-r--r--persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp2
-rw-r--r--persistence/src/vespa/persistence/spi/metricpersistenceprovider.h10
-rw-r--r--pom.xml935
-rw-r--r--searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp3
-rw-r--r--searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp6
-rw-r--r--searchlib/pom.xml15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java44
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java102
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java160
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java94
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java145
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java36
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj2
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py89
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt5039
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001bin0 -> 31400 bytes
-rw-r--r--searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.indexbin0 -> 159 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java118
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/postingchange.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/docstore/visitcache.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/docstore/visitcache.h2
-rw-r--r--staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp5
-rw-r--r--storage/src/tests/distributor/blockingoperationstartertest.cpp2
-rw-r--r--storage/src/tests/distributor/bucketdbupdatertest.cpp7
-rw-r--r--storage/src/tests/distributor/externaloperationhandlertest.cpp10
-rw-r--r--storage/src/tests/distributor/getoperationtest.cpp11
-rw-r--r--storage/src/tests/distributor/idealstatemanagertest.cpp12
-rw-r--r--storage/src/tests/distributor/maintenanceschedulertest.cpp8
-rw-r--r--storage/src/tests/distributor/messagesenderstub.h1
-rw-r--r--storage/src/tests/distributor/pendingmessagetrackertest.cpp6
-rw-r--r--storage/src/tests/distributor/putoperationtest.cpp2
-rw-r--r--storage/src/tests/distributor/simplemaintenancescannertest.cpp18
-rw-r--r--storage/src/tests/distributor/throttlingoperationstartertest.cpp2
-rw-r--r--storage/src/tests/distributor/visitoroperationtest.cpp7
-rw-r--r--storage/src/tests/persistence/splitbitdetectortest.cpp1
-rw-r--r--storage/src/tests/storageserver/CMakeLists.txt1
-rw-r--r--storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp137
-rw-r--r--storage/src/tests/storageserver/documentapiconvertertest.cpp31
-rw-r--r--storage/src/vespa/storage/common/bucketmessages.cpp1
-rw-r--r--storage/src/vespa/storage/common/messagesender.h5
-rw-r--r--storage/src/vespa/storage/common/storagecomponent.cpp40
-rw-r--r--storage/src/vespa/storage/common/storagecomponent.h18
-rw-r--r--storage/src/vespa/storage/distributor/bucketdbupdater.cpp2
-rw-r--r--storage/src/vespa/storage/distributor/bucketdbupdater.h3
-rw-r--r--storage/src/vespa/storage/distributor/bucketgctimecalculator.h7
-rw-r--r--storage/src/vespa/storage/distributor/bucketownership.h11
-rw-r--r--storage/src/vespa/storage/distributor/distributor.cpp49
-rw-r--r--storage/src/vespa/storage/distributor/distributorinterface.h32
-rw-r--r--storage/src/vespa/storage/distributor/distributormessagesender.h21
-rw-r--r--storage/src/vespa/storage/distributor/idealstatemanager.cpp2
-rw-r--r--storage/src/vespa/storage/distributor/idealstatemanager.h30
-rw-r--r--storage/src/vespa/storage/distributor/messagetracker.cpp11
-rw-r--r--storage/src/vespa/storage/distributor/messagetracker.h10
-rw-r--r--storage/src/vespa/storage/distributor/operations/external/putoperation.h12
-rw-r--r--storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h18
-rw-r--r--storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp13
-rw-r--r--storage/src/vespa/storage/distributor/operations/external/visitoroperation.h9
-rw-r--r--storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp20
-rw-r--r--storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp2
-rw-r--r--storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp1
-rw-r--r--storage/src/vespa/storage/distributor/operationtargetresolver.h6
-rw-r--r--storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp26
-rw-r--r--storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h9
-rw-r--r--storage/src/vespa/storage/persistence/splitbitdetector.h1
-rw-r--r--storage/src/vespa/storage/storageserver/CMakeLists.txt1
-rw-r--r--storage/src/vespa/storage/storageserver/communicationmanager.cpp3
-rw-r--r--storage/src/vespa/storage/storageserver/communicationmanager.h1
-rw-r--r--storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp36
-rw-r--r--storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h36
-rw-r--r--storage/src/vespa/storage/storageserver/documentapiconverter.cpp34
-rw-r--r--storage/src/vespa/storage/storageserver/documentapiconverter.h10
-rw-r--r--storage/src/vespa/storage/storageserver/storagenode.cpp64
-rw-r--r--storage/src/vespa/storage/storageserver/storagenode.h26
-rw-r--r--storageapi/src/vespa/storageapi/message/state.h6
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java12
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/net/HostName.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java352
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java79
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java56
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java62
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java64
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java68
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java119
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java21
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java97
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java2
-rw-r--r--vespalib/src/vespa/vespalib/stllike/hash_map.h2
-rw-r--r--vespalib/src/vespa/vespalib/stllike/hash_map.hpp26
-rw-r--r--vespalib/src/vespa/vespalib/stllike/hashtable.h96
-rw-r--r--vespalib/src/vespa/vespalib/stllike/hashtable.hpp38
-rw-r--r--zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java24
-rw-r--r--zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java (renamed from zkfacade/src/test/java/com/yahoo/vespa/zookeeper/CuratorTest.java)22
316 files changed, 9650 insertions, 2023 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
index c1a786194a2..7506c884715 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
@@ -7,6 +7,7 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.path.Path;
import com.yahoo.io.IOUtils;
import com.yahoo.io.reader.NamedReader;
+import com.yahoo.path.Path;
import com.yahoo.text.XML;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import org.w3c.dom.Element;
@@ -14,8 +15,17 @@ import org.xml.sax.SAXException;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.TransformerException;
-import java.io.*;
-import java.util.*;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.Reader;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
@@ -228,9 +238,9 @@ public interface ApplicationPackage {
throw new UnsupportedOperationException("This application package cannot write its metadata");
}
- /**
- * Returns the single host allocation info of this, or an empty map if no allocation is available
- *
+ /**
+ * Returns the single host allocation info of this, or an empty map if no allocation is available
+ *
* @deprecated please use #getAllocatedHosts
*/
// TODO: Remove on Vespa 7
@@ -261,7 +271,8 @@ public interface ApplicationPackage {
*
* @return A new application package instance pointing to a new location
*/
- default ApplicationPackage preprocess(Zone zone, RuleConfigDeriver ruleConfigDeriver, DeployLogger logger) throws IOException, TransformerException, ParserConfigurationException, SAXException {
+ default ApplicationPackage preprocess(Zone zone, RuleConfigDeriver ruleConfigDeriver, DeployLogger logger)
+ throws IOException, TransformerException, ParserConfigurationException, SAXException {
throw new UnsupportedOperationException("This application package does not support preprocessing");
}
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
index cee501841b4..61cab2f6ce7 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
@@ -4,10 +4,9 @@ package com.yahoo.config.application.api;
import java.util.logging.Level;
/**
- * Used during application deployment to persist and propagate messages to end user
+ * Used during application deployment to propagate messages to the end user
*
- * @author lulf
- * @since 5.1
+ * @author Ulf Lillengen
*/
public interface DeployLogger {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
index 69e353ceb35..65176006a2a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java
@@ -118,7 +118,7 @@ public class TensorTransformer extends ExpressionTransformer {
private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
-
+
TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
index c52a5dc465d..f932265cb93 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java
@@ -259,7 +259,7 @@ public final class Attribute implements Cloneable, Serializable {
throw new IllegalArgumentException("Field " + fieldType + " not supported in convertCollectionType");
}
}
-
+
private static Optional<TensorType> convertTensorType(DataType fieldType) {
if ( ! ( fieldType instanceof TensorDataType)) return Optional.empty();
return Optional.of(((TensorDataType)fieldType).getTensorType());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
index c8918f39834..8b6df1a87db 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
@@ -29,7 +29,7 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public <T extends Expression> boolean containsExpression(Class<T> searchFor) {
- throw createUnsupportedException();
+ throw createUnsupportedException(searchFor.getSimpleName());
}
@Override
@@ -79,9 +79,9 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public Index getIndex(String name) {
- if (!importedField.fieldName().equals(name)) {
+ if ( ! importedField.fieldName().equals(name)) {
throw new IllegalArgumentException("Getting an index (" + name + ") with different name than the imported field ("
- + importedField.fieldName() + ") is not supported");
+ + importedField.fieldName() + ") is not supported");
}
String targetIndexName = importedField.targetField().getName();
return importedField.targetField().getIndex(targetIndexName);
@@ -104,7 +104,7 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public ScriptExpression getIndexingScript() {
- throw createUnsupportedException();
+ throw createUnsupportedException("indexing");
}
@Override
@@ -119,12 +119,12 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public ImmutableSDField getStructField(String name) {
- throw createUnsupportedException();
+ throw createUnsupportedException("struct");
}
@Override
public Collection<? extends ImmutableSDField> getStructFields() {
- throw createUnsupportedException();
+ throw createUnsupportedException("struct");
}
@Override
@@ -134,12 +134,12 @@ public class ImmutableImportedSDField implements ImmutableSDField {
@Override
public Stemming getStemming(Search search) {
- throw createUnsupportedException();
+ throw createUnsupportedException("stemming");
}
@Override
public Ranking getRanking() {
- throw createUnsupportedException();
+ throw createUnsupportedException("ranking");
}
@Override
@@ -158,8 +158,8 @@ public class ImmutableImportedSDField implements ImmutableSDField {
importedField.targetField().getDataType());
}
- private static UnsupportedOperationException createUnsupportedException() {
- return new UnsupportedOperationException("This aspect is not meaningful or relevant for an imported field.");
+ private static UnsupportedOperationException createUnsupportedException(String aspect) {
+ return new UnsupportedOperationException("'" + aspect + "' is not meaningful or relevant for an imported field.");
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
index 96a9448739a..9368d6aaa39 100644
--- a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
+++ b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java
@@ -20,7 +20,7 @@ import java.util.Set;
*/
public class DocumentManager {
- public DocumentmanagerConfig.Builder produce(DocumentModel model,
+ public DocumentmanagerConfig.Builder produce(DocumentModel model,
DocumentmanagerConfig.Builder documentConfigBuilder) {
documentConfigBuilder.enablecompression(false);
Set<DataType> handled = new HashSet<>();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index 6eeb12ffdd9..ce3c04f41f7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -45,7 +45,7 @@ public class ConstantTensorJsonValidator {
throw new IllegalArgumentException("Ranking constant file names must end with either '.json' or '.json.lz4'");
}
}
-
+
private void validateTensor(TensorType type, Reader tensorData) {
wrapIOException(() -> {
this.parser = jsonFactory.createParser(tensorData);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
index 4a9310799aa..c686f023d5b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java
@@ -64,7 +64,7 @@ public class RankingConstantsValidator extends Validator {
private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage applicationPackage) throws FileNotFoundException {
ApplicationFile tensorApplicationFile = applicationPackage.getFile(Path.fromString(rankingConstant.getFileName()));
- new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(),
+ new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(),
rankingConstant.getTensorType(),
tensorApplicationFile.createReader());
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
index 4383e55e45d..28a54771c21 100755
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java
@@ -220,6 +220,7 @@ public final class ContainerCluster
addSimpleComponent("com.yahoo.container.jdisc.metric.MetricConsumerProviderProvider");
addSimpleComponent("com.yahoo.container.jdisc.metric.MetricProvider");
addSimpleComponent("com.yahoo.container.jdisc.metric.MetricUpdater");
+ addSimpleComponent(com.yahoo.container.jdisc.LoggingRequestHandler.Context.class);
addSimpleComponent(com.yahoo.metrics.simple.MetricManager.class.getName(), null, MetricProperties.BUNDLE_SYMBOLIC_NAME);
addSimpleComponent(com.yahoo.metrics.simple.jdisc.JdiscMetricsFactory.class.getName(), null, MetricProperties.BUNDLE_SYMBOLIC_NAME);
addSimpleComponent("com.yahoo.container.jdisc.state.StateMonitor");
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
index a5b7d67e377..e1675007bbc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
@@ -8,6 +8,7 @@ import com.yahoo.vespa.model.content.Redundancy;
* Builds redundancy config for a content cluster.
*/
public class RedundancyBuilder {
+
Redundancy build(ModelElement clusterXml) {
Integer initialRedundancy = 2;
Integer finalRedundancy = 3;
@@ -37,4 +38,5 @@ public class RedundancyBuilder {
return new Redundancy(initialRedundancy, finalRedundancy, readyCopies);
}
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 9407c21fee8..960a3b7d6db 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -173,7 +173,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent());
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.numeric").isPresent());
}
-
+
private static Optional<String> findProperty(List<Pair<String, String>> properties, String key) {
for (Pair<String, String> property : properties)
if (property.getFirst().equals(key))
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
index 7cd00e155bb..4600f6ae4c6 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
@@ -123,7 +123,7 @@ public class ExportingTestCase extends AbstractExportingTestCase {
public void testIndexinfoFieldsets() throws IOException, ParseException {
assertCorrectDeriving("indexinfo_fieldsets");
}
-
+
@Test
public void testStreamingJuniper() throws IOException, ParseException {
assertCorrectDeriving("streamingjuniper");
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 12bdd8d2b5c..e5693d24f0f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -202,5 +202,5 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
}
return b.toString();
}
-
+
}
diff --git a/configd/src/apps/sentinel/service.cpp b/configd/src/apps/sentinel/service.cpp
index 3c762a957ec..5633c356bc7 100644
--- a/configd/src/apps/sentinel/service.cpp
+++ b/configd/src/apps/sentinel/service.cpp
@@ -113,8 +113,7 @@ Service::terminate(bool catchable, bool dumpState)
ret == 0 ? "OK" : strerror(errno));
return ret;
} else {
- setState(KILLING);
- if (dumpState) {
+ if (dumpState && _state != KILLING) {
vespalib::string pstackCmd = make_string("pstack %d > %s/%s.pstack.%d",
_pid, getVespaTempDir().c_str(), name().c_str(), _pid);
LOG(info, "%s:%d failed to stop. Stack dumped at %s", name().c_str(), _pid, pstackCmd.c_str());
@@ -123,6 +122,7 @@ Service::terminate(bool catchable, bool dumpState)
LOG(warning, "'%s' failed with return value %d", pstackCmd.c_str(), pstackRet);
}
}
+ setState(KILLING);
kill(_pid, SIGCONT); // if it was stopped for some reason
int ret = kill(_pid, SIGKILL);
LOG(debug, "%s: kill -SIGKILL %d: %s", name().c_str(), (int)_pid,
diff --git a/configdefinitions/src/vespa/configserver.def b/configdefinitions/src/vespa/configserver.def
index cbc2317da2d..9072a20c006 100644
--- a/configdefinitions/src/vespa/configserver.def
+++ b/configdefinitions/src/vespa/configserver.def
@@ -27,6 +27,7 @@ payloadCompressionType enum { UNCOMPRESSED, LZ4 } default=LZ4
serverId string default="localhost"
hostedVespa bool default=false
numParallelTenantLoaders int default=4
+zookeeperLocalhostAffinity bool default=false
# Zone information
environment string default="prod"
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java
index 39cd4629ff0..925f8324b30 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java
@@ -35,7 +35,6 @@ public class ApplicationConvergenceChecker extends AbstractComponent {
private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList(
"container",
- "container-clustercontroller",
"qrserver",
"docprocservice",
"searchnode",
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java
index 1046ed93491..819f9a9d5d6 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java
@@ -21,11 +21,12 @@ import java.util.logging.Logger;
public class CombinedLegacyDistribution implements FileDistribution {
private final static Logger log = Logger.getLogger(CombinedLegacyDistribution.class.getName());
- private final Supervisor supervisor = new Supervisor(new Transport());
+ private final Supervisor supervisor;
private final FileDistribution legacy;
private final boolean disableFileDistributor;
- CombinedLegacyDistribution(FileDBHandler legacy, boolean disableFileDistributor) {
+ CombinedLegacyDistribution(Supervisor supervisor, FileDBHandler legacy, boolean disableFileDistributor) {
+ this.supervisor = supervisor;
this.legacy = legacy;
this.disableFileDistributor = disableFileDistributor;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java
index 38fa3087f88..cd3f0f7f167 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.config.server.filedistribution;
import com.yahoo.config.FileReference;
import com.yahoo.config.model.api.FileDistribution;
import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.jrt.Supervisor;
import com.yahoo.vespa.filedistribution.FileDistributionManager;
import java.io.File;
@@ -35,16 +36,17 @@ public class FileDistributionProvider {
}
}
- public FileDistributionProvider(File applicationDir, String zooKeepersSpec,
+ public FileDistributionProvider(Supervisor supervisor, File applicationDir, String zooKeepersSpec,
String applicationId, Lock fileDistributionLock,
boolean disableFileDistributor) {
ensureDirExists(FileDistribution.getDefaultFileDBPath());
final FileDistributionManager manager = new FileDistributionManager(
FileDistribution.getDefaultFileDBPath(), applicationDir,
zooKeepersSpec, applicationId, fileDistributionLock);
- this.fileDistribution = new CombinedLegacyDistribution(new FileDBHandler(manager), disableFileDistributor);
+ this.fileDistribution = new CombinedLegacyDistribution(supervisor, new FileDBHandler(manager), disableFileDistributor);
this.fileRegistry = new CombinedLegacyRegistry(new FileDBRegistry(new ManagerWrapper(manager)),
new FileDBRegistry(new ApplicationFileManager(applicationDir, new FileDirectory())));
+
}
// For testing only
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java
index 3ec4d1b6e46..94707635950 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java
@@ -23,14 +23,14 @@ import java.util.concurrent.Executor;
public class HttpGetConfigHandler extends HttpHandler {
private final RequestHandler requestHandler;
- public HttpGetConfigHandler(Executor executor, RequestHandler requestHandler, AccessLog accessLog) {
- super(executor, accessLog);
+ public HttpGetConfigHandler(HttpHandler.Context ctx, RequestHandler requestHandler) {
+ super(ctx);
this.requestHandler = requestHandler;
}
@Inject
- public HttpGetConfigHandler(Executor executor, Tenants tenants, AccessLog accesslog) {
- this(executor, tenants.defaultTenant().getRequestHandler(), accesslog);
+ public HttpGetConfigHandler(HttpHandler.Context ctx, Tenants tenants) {
+ this(ctx, tenants.defaultTenant().getRequestHandler());
}
@Override
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java
index cc78c2715e2..e8db448b245 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.http;
+import com.google.inject.Inject;
+
import com.yahoo.config.provision.ApplicationLockException;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
@@ -25,8 +27,8 @@ import java.util.concurrent.Executor;
*/
public class HttpHandler extends LoggingRequestHandler {
- public HttpHandler(Executor executor, AccessLog accessLog) {
- super(executor, accessLog);
+ public HttpHandler(HttpHandler.Context ctx) {
+ super(ctx);
}
@Override
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java
index 5ea0b38c110..64361c0771c 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java
@@ -32,12 +32,12 @@ public class HttpListConfigsHandler extends HttpHandler {
private final RequestHandler requestHandler;
@Inject
- public HttpListConfigsHandler(Executor executor, AccessLog accessLog, Tenants tenants) {
- this(executor, accessLog, tenants.defaultTenant().getRequestHandler());
+ public HttpListConfigsHandler(HttpHandler.Context ctx, Tenants tenants) {
+ this(ctx, tenants.defaultTenant().getRequestHandler());
}
- public HttpListConfigsHandler(Executor executor, AccessLog accessLog, RequestHandler requestHandler) {
- super(executor, accessLog);
+ public HttpListConfigsHandler(HttpHandler.Context ctx, RequestHandler requestHandler) {
+ super(ctx);
this.requestHandler = requestHandler;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java
index 7c51fd131ff..81163d79341 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java
@@ -25,14 +25,16 @@ import java.util.concurrent.Executor;
public class HttpListNamedConfigsHandler extends HttpHandler {
private final RequestHandler requestHandler;
- public HttpListNamedConfigsHandler(Executor executor, RequestHandler requestHandler, AccessLog accessLog) {
- super(executor, accessLog);
+ public HttpListNamedConfigsHandler(HttpHandler.Context ctx,
+ RequestHandler requestHandler) {
+ super(ctx);
this.requestHandler = requestHandler;
}
@Inject
- public HttpListNamedConfigsHandler(Executor executor, Tenants tenants, AccessLog accessLog) {
- this(executor, tenants.defaultTenant().getRequestHandler(), accessLog);
+ public HttpListNamedConfigsHandler(HttpHandler.Context ctx,
+ Tenants tenants) {
+ this(ctx, tenants.defaultTenant().getRequestHandler());
}
@Override
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java
index 40ffc8e9da3..5acb6e81a83 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.http;
+import com.google.inject.Inject;
+
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.logging.AccessLog;
@@ -27,8 +29,9 @@ public class SessionHandler extends HttpHandler {
protected final ApplicationRepository applicationRepository;
- public SessionHandler(Executor executor, AccessLog accessLog, ApplicationRepository applicationRepository) {
- super(executor, accessLog);
+ public SessionHandler(HttpHandler.Context ctx, ApplicationRepository applicationRepository)
+ {
+ super(ctx);
this.applicationRepository = applicationRepository;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
index ef122147d79..819f1a35cf3 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.http.v2;
+import com.google.inject.Inject;
+
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ApplicationName;
@@ -37,11 +39,11 @@ public class ApplicationHandler extends HttpHandler {
private final Zone zone;
private final ApplicationRepository applicationRepository;
- public ApplicationHandler(Executor executor,
- AccessLog accessLog,
+ @Inject
+ public ApplicationHandler(HttpHandler.Context ctx,
Zone zone,
ApplicationRepository applicationRepository) {
- super(executor, accessLog);
+ super(ctx);
this.zone = zone;
this.applicationRepository = applicationRepository;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java
index 2acaa67baef..13933544ad1 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.http.v2;
+import com.google.inject.Inject;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.Zone;
@@ -28,8 +29,10 @@ public class HostHandler extends HttpHandler {
final HostRegistries hostRegistries;
private final Zone zone;
- public HostHandler(Executor executor, AccessLog accessLog, GlobalComponentRegistry globalComponentRegistry) {
- super(executor, accessLog);
+ @Inject
+ public HostHandler(HttpHandler.Context ctx,
+ GlobalComponentRegistry globalComponentRegistry) {
+ super(ctx);
this.hostRegistries = globalComponentRegistry.getHostRegistries();
this.zone = globalComponentRegistry.getZone();
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java
index 1b566fbb9c5..0ca720c9710 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java
@@ -27,8 +27,10 @@ public class HttpGetConfigHandler extends HttpHandler {
private final Tenants tenants;
@Inject
- public HttpGetConfigHandler(Executor executor, AccessLog accesslog, Tenants tenants) {
- super(executor, accesslog);
+ public HttpGetConfigHandler(HttpHandler.Context ctx,
+ Tenants tenants)
+ {
+ super(ctx);
this.tenants = tenants;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java
index ea3a1a2c9f4..2a9e2b1ecf4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java
@@ -34,8 +34,10 @@ public class HttpListConfigsHandler extends HttpHandler {
private final Zone zone;
@Inject
- public HttpListConfigsHandler(Executor executor, AccessLog accesslog, Tenants tenants, Zone zone) {
- super(executor, accesslog);
+ public HttpListConfigsHandler(HttpHandler.Context ctx,
+ Tenants tenants, Zone zone)
+ {
+ super(ctx);
this.tenants = tenants;
this.zone = zone;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java
index 2262b8bc722..0a55d3585e0 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java
@@ -29,8 +29,10 @@ public class HttpListNamedConfigsHandler extends HttpHandler {
private final Zone zone;
@Inject
- public HttpListNamedConfigsHandler(Executor executor, AccessLog accesslog, Tenants tenants, Zone zone) {
- super(executor, accesslog);
+ public HttpListNamedConfigsHandler(HttpHandler.Context ctx,
+ Tenants tenants, Zone zone)
+ {
+ super(ctx);
this.tenants = tenants;
this.zone = zone;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java
index 79f551c270b..42872881088 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.http.v2;
+import com.google.inject.Inject;
import com.google.common.base.Function;
import com.google.common.collect.Collections2;
import com.yahoo.config.provision.TenantName;
@@ -29,8 +30,11 @@ import java.util.concurrent.Executor;
public class ListApplicationsHandler extends HttpHandler {
private final Tenants tenants;
private final Zone zone;
- public ListApplicationsHandler(Executor executor, AccessLog accessLog, Tenants tenants, Zone zone) {
- super(executor, accessLog);
+
+ @Inject
+ public ListApplicationsHandler(HttpHandler.Context ctx,
+ Tenants tenants, Zone zone) {
+ super(ctx);
this.tenants = tenants;
this.zone = zone;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java
index f1c75ff0a01..b2330ebd97f 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java
@@ -33,12 +33,11 @@ public class SessionActiveHandler extends SessionHandler {
private final Zone zone;
@Inject
- public SessionActiveHandler(Executor executor,
- AccessLog accessLog,
+ public SessionActiveHandler(SessionHandler.Context ctx,
+ ApplicationRepository applicationRepository,
Tenants tenants,
- Zone zone,
- ApplicationRepository applicationRepository) {
- super(executor, accessLog, applicationRepository);
+ Zone zone) {
+ super(ctx, applicationRepository);
this.tenants = tenants;
this.zone = zone;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java
index c9d5407e0e3..524eb01e625 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java
@@ -28,11 +28,11 @@ public class SessionContentHandler extends SessionHandler {
private final ContentHandler contentHandler = new ContentHandler();
@Inject
- public SessionContentHandler(Executor executor,
- AccessLog accessLog,
- Tenants tenants,
- ApplicationRepository applicationRepository) {
- super(executor, accessLog, applicationRepository);
+ public SessionContentHandler(SessionHandler.Context ctx,
+ ApplicationRepository applicationRepository,
+ Tenants tenants)
+ {
+ super(ctx, applicationRepository);
this.tenants = tenants;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java
index 5908851e399..b0c251f477c 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java
@@ -49,12 +49,11 @@ public class SessionCreateHandler extends SessionHandler {
private final Duration zookeeperBarrierTimeout;
@Inject
- public SessionCreateHandler(Executor executor,
- AccessLog accessLog,
+ public SessionCreateHandler(SessionHandler.Context ctx,
+ ApplicationRepository applicationRepository,
Tenants tenants,
- ConfigserverConfig configserverConfig,
- ApplicationRepository applicationRepository) {
- super(executor, accessLog, applicationRepository);
+ ConfigserverConfig configserverConfig) {
+ super(ctx, applicationRepository);
this.tenants = tenants;
this.zookeeperBarrierTimeout = Duration.ofSeconds(configserverConfig.zookeeper().barrierTimeout());
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java
index 03a3f3556e4..2b432a50ee1 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java
@@ -41,12 +41,11 @@ public class SessionPrepareHandler extends SessionHandler {
private final Duration zookeeperBarrierTimeout;
@Inject
- public SessionPrepareHandler(Executor executor,
- AccessLog accessLog,
+ public SessionPrepareHandler(SessionHandler.Context ctx,
+ ApplicationRepository applicationRepository,
Tenants tenants,
- ConfigserverConfig configserverConfig,
- ApplicationRepository applicationRepository) {
- super(executor, accessLog, applicationRepository);
+ ConfigserverConfig configserverConfig) {
+ super(ctx, applicationRepository);
this.tenants = tenants;
this.zookeeperBarrierTimeout = Duration.ofSeconds(configserverConfig.zookeeper().barrierTimeout());
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java
index 5c1d8a36f6a..955bba5f5b4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.config.server.http.v2;
import java.util.List;
import java.util.concurrent.Executor;
+import com.google.inject.Inject;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.TenantName;
@@ -29,8 +30,10 @@ public class TenantHandler extends HttpHandler {
private static final String TENANT_NAME_REGEXP = "[\\w-]+";
private final Tenants tenants;
- public TenantHandler(Executor executor, AccessLog accessLog, Tenants tenants) {
- super(executor, accessLog);
+ @Inject
+ public TenantHandler(HttpHandler.Context ctx,
+ Tenants tenants) {
+ super(ctx);
this.tenants = tenants;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
index 99a34a45a2f..243c47ba3d7 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.config.server.session;
import com.google.inject.Inject;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.jrt.Supervisor;
+import com.yahoo.jrt.Transport;
import com.yahoo.vespa.config.server.filedistribution.FileDistributionLock;
import com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider;
import com.yahoo.vespa.curator.Curator;
@@ -21,6 +23,7 @@ public class FileDistributionFactory {
private static final String lockPath = "/vespa/filedistribution/lock";
private final String zkSpec;
private final Lock lock;
+ private final Supervisor supervisor = new Supervisor(new Transport());
@Inject
public FileDistributionFactory(Curator curator) {
@@ -33,7 +36,12 @@ public class FileDistributionFactory {
}
public FileDistributionProvider createProvider(File applicationPackage, ApplicationId applicationId, boolean disableFileDistributor) {
- return new FileDistributionProvider(applicationPackage, zkSpec, applicationId.serializedForm(), lock, disableFileDistributor);
+ return new FileDistributionProvider(supervisor, applicationPackage, zkSpec, applicationId.serializedForm(), lock, disableFileDistributor);
}
+ @Override
+ protected void finalize() throws Throwable {
+ super.finalize();
+ supervisor.transport().shutdown().join();
+ }
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java
index 71f4e4add50..b19d6e2e257 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java
@@ -44,13 +44,9 @@ public class HttpGetConfigHandlerTest {
mockRequestHandler.setAllConfigs(new HashSet<ConfigKey<?>>() {{
add(new ConfigKey<>("bar", "myid", "foo"));
}} );
- handler = new HttpGetConfigHandler(new Executor() {
- @SuppressWarnings("NullableProblems")
- @Override
- public void execute(Runnable command) {
- command.run();
- }
- }, mockRequestHandler, AccessLog.voidAccessLog());
+ handler = new HttpGetConfigHandler(
+ HttpGetConfigHandler.testOnlyContext(),
+ mockRequestHandler);
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java
index 76844bb7c21..bf881e7a546 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java
@@ -25,7 +25,7 @@ public class HttpHandlerTest {
@Test
public void testResponse() throws IOException {
final String message = "failed";
- HttpHandler httpHandler = new HttpTestHandler(Executors.newSingleThreadExecutor(), AccessLog.voidAccessLog(), new InvalidApplicationException(message));
+ HttpHandler httpHandler = new HttpTestHandler(new InvalidApplicationException(message));
HttpResponse response = httpHandler.handle(HttpRequest.createTestRequest("foo", com.yahoo.jdisc.http.HttpRequest.Method.GET));
assertThat(response.getStatus(), is(Response.Status.BAD_REQUEST));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -38,8 +38,8 @@ public class HttpHandlerTest {
private static class HttpTestHandler extends HttpHandler {
private RuntimeException exception;
- public HttpTestHandler(Executor executor, AccessLog accessLog, RuntimeException exception) {
- super(executor, accessLog);
+ public HttpTestHandler(RuntimeException exception) {
+ super(HttpHandler.testOnlyContext());
this.exception = exception;
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java
index db8526150bf..01618e5a85f 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java
@@ -37,18 +37,9 @@ public class HttpListConfigsHandlerTest {
mockRequestHandler.setAllConfigs(new HashSet<ConfigKey<?>>() {{
add(new ConfigKey<>("bar", "conf/id/", "foo"));
}} );
- handler = new HttpListConfigsHandler(new Executor() {
- @Override
- public void execute(Runnable command) {
- command.run();
- }
- }, AccessLog.voidAccessLog(), mockRequestHandler);
- namedHandler = new HttpListNamedConfigsHandler(new Executor() {
- @Override
- public void execute(Runnable command) {
- command.run();
- }
- }, mockRequestHandler, AccessLog.voidAccessLog());
+ HttpListConfigsHandler.Context ctx = HttpListConfigsHandler.testOnlyContext();
+ handler = new HttpListConfigsHandler(ctx, mockRequestHandler);
+ namedHandler = new HttpListNamedConfigsHandler(ctx, mockRequestHandler);
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java
index 7aff8f9410b..b6d9ab5d618 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java
@@ -54,7 +54,7 @@ public class SessionExampleHandlerTest {
public static class SessionExampleHandler extends ThreadedHttpRequestHandler {
public SessionExampleHandler(Executor executor) {
- super(executor);
+ super(executor, null);
}
@Override
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
index a17d485a425..c34dbe76a43 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java
@@ -52,8 +52,7 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase {
testTenantBuilder.tenants().get(tenant2).getLocalSessionRepo().addSession(new MockSession(3l, FilesApplicationPackage.fromFile(new File("src/test/apps/content2"))));
testTenantBuilder.tenants().get(tenant1).getApplicationRepo().createPutApplicationTransaction(idTenant1, 2l).commit();
testTenantBuilder.tenants().get(tenant2).getApplicationRepo().createPutApplicationTransaction(idTenant2, 3l).commit();
- handler = new ApplicationHandler(Runnable::run,
- AccessLog.voidAccessLog(),
+ handler = new ApplicationHandler(ApplicationHandler.testOnlyContext(),
Zone.defaultZone(),
new ApplicationRepository(testTenantBuilder.createTenants(),
new MockProvisioner(),
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
index 5552758a0a6..8ac64e5b28a 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
@@ -96,7 +96,8 @@ public class ApplicationHandlerTest {
mockHttpProxy,
new MockLogServerLogGrabber());
listApplicationsHandler = new ListApplicationsHandler(
- Runnable::run, AccessLog.voidAccessLog(), tenants, Zone.defaultZone());
+ ListApplicationsHandler.testOnlyContext(),
+ tenants, Zone.defaultZone());
}
private ApplicationHandler createMockApplicationHandler(
@@ -105,8 +106,7 @@ public class ApplicationHandlerTest {
HttpProxy httpProxy,
LogServerLogGrabber logServerLogGrabber) {
return new ApplicationHandler(
- Runnable::run,
- AccessLog.voidAccessLog(),
+ ApplicationHandler.testOnlyContext(),
Zone.defaultZone(),
new ApplicationRepository(tenants,
HostProvisionerProvider.withProvisioner(provisioner),
@@ -118,8 +118,7 @@ public class ApplicationHandlerTest {
private ApplicationHandler createApplicationHandler(Tenants tenants) {
return new ApplicationHandler(
- Runnable::run,
- AccessLog.voidAccessLog(),
+ ApplicationHandler.testOnlyContext(),
Zone.defaultZone(),
new ApplicationRepository(tenants,
HostProvisionerProvider.withProvisioner(provisioner),
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java
index e439f424c45..fe25170d8ba 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java
@@ -52,9 +52,9 @@ public class HostHandlerTest {
hostRegistries = testComponentRegistry.getHostRegistries();
hostRegistries.createApplicationHostRegistry(mytenant).update(ApplicationId.from(mytenant, ApplicationName.defaultName(), InstanceName.defaultName()), Collections.singletonList(hostname));
hostRegistries.getTenantHostRegistry().update(mytenant, Collections.singletonList(hostname));
- hostHandler = new HostHandler(command -> {
- command.run();
- }, AccessLog.voidAccessLog(), testComponentRegistry);
+ hostHandler = new HostHandler(
+ HostHandler.testOnlyContext(),
+ testComponentRegistry);
return hostHandler;
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java
index cc18e279002..11bacc30b27 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java
@@ -49,9 +49,9 @@ public class HttpGetConfigHandlerTest {
TestTenantBuilder tb = new TestTenantBuilder();
tb.createTenant(tenant).withRequestHandler(mockRequestHandler).build();
Tenants tenants = tb.createTenants();
- handler = new HttpGetConfigHandler(command -> {
- command.run();
- }, AccessLog.voidAccessLog(), tenants);
+ handler = new HttpGetConfigHandler(
+ HttpGetConfigHandler.testOnlyContext(),
+ tenants);
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java
index a66e9542a5f..e7ccd9f957e 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java
@@ -45,12 +45,12 @@ public class HttpListConfigsHandlerTest {
TestTenantBuilder tb = new TestTenantBuilder();
tb.createTenant(TenantName.from("mytenant")).withRequestHandler(mockRequestHandler).build();
Tenants tenants = tb.createTenants();
- handler = new HttpListConfigsHandler(command -> {
- command.run();
- }, AccessLog.voidAccessLog(), tenants, Zone.defaultZone());
- namedHandler = new HttpListNamedConfigsHandler(command -> {
- command.run();
- }, AccessLog.voidAccessLog(), tenants, Zone.defaultZone());
+ handler = new HttpListConfigsHandler(
+ HttpListConfigsHandler.testOnlyContext(),
+ tenants, Zone.defaultZone());
+ namedHandler = new HttpListNamedConfigsHandler(
+ HttpListConfigsHandler.testOnlyContext(),
+ tenants, Zone.defaultZone());
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
index 9e7853a8fdf..3233d9598d1 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java
@@ -39,10 +39,10 @@ public class ListApplicationsHandlerTest {
applicationRepo = testBuilder.tenants().get(mytenant).getApplicationRepo();
applicationRepo2 = testBuilder.tenants().get(foobar).getApplicationRepo();
Tenants tenants = testBuilder.createTenants();
- handler = new ListApplicationsHandler(Runnable::run,
- AccessLog.voidAccessLog(),
- tenants,
- new Zone(Environment.dev, RegionName.from("us-east")));
+ handler = new ListApplicationsHandler(
+ ListApplicationsHandler.testOnlyContext(),
+ tenants,
+ new Zone(Environment.dev, RegionName.from("us-east")));
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
index 6542c865d56..04bc8d7b49a 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java
@@ -373,13 +373,12 @@ public class SessionActiveHandlerTest extends SessionHandlerTest {
.withApplicationRepo(applicationRepo)
.build();
return new SessionActiveHandler(
- Runnable::run,
- AccessLog.voidAccessLog(),
- testTenantBuilder.createTenants(),
- Zone.defaultZone(),
+ SessionActiveHandler.testOnlyContext(),
new ApplicationRepository(testTenantBuilder.createTenants(),
hostProvisioner,
- Clock.systemUTC()));
+ Clock.systemUTC()),
+ testTenantBuilder.createTenants(),
+ Zone.defaultZone());
}
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
index 1d831032416..e4841930cc8 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java
@@ -161,15 +161,11 @@ public class SessionContentHandlerTest extends ContentHandlerTestBase {
private SessionContentHandler createHandler() throws Exception {
TestTenantBuilder testTenantBuilder = new TestTenantBuilder();
testTenantBuilder.createTenant(tenant).getLocalSessionRepo().addSession(new MockSession(1l, FilesApplicationPackage.fromFile(createTestApp())));
- return new SessionContentHandler(new Executor() {
- @SuppressWarnings("NullableProblems")
- @Override
- public void execute(Runnable command) {
- command.run();
- }
- }, AccessLog.voidAccessLog(), testTenantBuilder.createTenants(),
- new ApplicationRepository(testTenantBuilder.createTenants(),
- new SessionHandlerTest.MockProvisioner(),
- Clock.systemUTC()));
+ return new SessionContentHandler(
+ SessionContentHandler.testOnlyContext(),
+ new ApplicationRepository(testTenantBuilder.createTenants(),
+ new SessionHandlerTest.MockProvisioner(),
+ Clock.systemUTC()),
+ testTenantBuilder.createTenants());
}
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
index 65b12490b17..fc9264a6ef5 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java
@@ -243,10 +243,13 @@ public class SessionCreateHandlerTest extends SessionHandlerTest {
private SessionCreateHandler createHandler(Tenants tenants) throws Exception {
TestTenantBuilder testTenantBuilder = new TestTenantBuilder();
final ConfigserverConfig configserverConfig = new ConfigserverConfig(new ConfigserverConfig.Builder());
- return new SessionCreateHandler(Runnable::run, AccessLog.voidAccessLog(), tenants, configserverConfig,
- new ApplicationRepository(testTenantBuilder.createTenants(),
- new SessionHandlerTest.MockProvisioner(),
- Clock.systemUTC()));
+ return new SessionCreateHandler(
+ SessionCreateHandler.testOnlyContext(),
+ new ApplicationRepository(testTenantBuilder.createTenants(),
+ new SessionHandlerTest.MockProvisioner(),
+ Clock.systemUTC()),
+ tenants, configserverConfig);
+
}
private HttpRequest post() throws FileNotFoundException {
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
index 74a2dcf8054..1759cd68062 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java
@@ -383,10 +383,13 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest {
private SessionHandler createHandler(TestTenantBuilder builder) {
final ConfigserverConfig configserverConfig = new ConfigserverConfig(new ConfigserverConfig.Builder());
- return new SessionPrepareHandler(Runnable::run, AccessLog.voidAccessLog(), builder.createTenants(), configserverConfig,
- new ApplicationRepository(builder.createTenants(),
- new MockProvisioner(),
- Clock.systemUTC()));
+ return new SessionPrepareHandler(
+ SessionPrepareHandler.testOnlyContext(),
+ new ApplicationRepository(builder.createTenants(),
+ new MockProvisioner(),
+ Clock.systemUTC()),
+ builder.createTenants(), configserverConfig);
+
}
private TestTenantBuilder addTenant(TenantName tenantName,
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java
index ce4b25fe529..e948bf68970 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java
@@ -27,7 +27,9 @@ public class TenantHandlerTest extends TenantTest {
@Before
public void setup() throws Exception {
- handler = new TenantHandler(testExecutor(), null, tenants);
+ handler = new TenantHandler(
+ TenantHandler.testOnlyContext(),
+ tenants);
}
@Test
diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java
index 85026296363..1bbc08aa0a7 100644
--- a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java
+++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java
@@ -12,7 +12,6 @@ import java.math.BigDecimal;
import java.math.RoundingMode;
import java.net.URI;
import java.security.Principal;
-import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -49,8 +48,6 @@ public class JSONFormatter {
generator.writeStartObject();
generator.writeStringField("ip", accessLogEntry.getIpV4Address());
generator.writeNumberField("time", toTimestampInSeconds(accessLogEntry.getTimeStampMillis()));
- generator.writeStringField("time-iso8601",
- Instant.ofEpochMilli(accessLogEntry.getTimeStampMillis()).toString());
generator.writeNumberField("duration",
durationAsSeconds(accessLogEntry.getDurationBetweenRequestResponseMillis()));
generator.writeNumberField("responsesize", accessLogEntry.getReturnedContentSize());
diff --git a/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java b/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java
index 7f81a3568dd..ae27d7b1814 100644
--- a/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java
+++ b/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java
@@ -40,7 +40,6 @@ public class JSONLogTestCase extends junit.framework.TestCase {
String expectedOutput =
"{\"ip\":\"152.200.54.243\"," +
"\"time\":920880005.023," +
- "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," +
"\"duration\":0.122," +
"\"responsesize\":9875," +
"\"code\":200," +
@@ -69,7 +68,6 @@ public class JSONLogTestCase extends junit.framework.TestCase {
String expectedOutput =
"{\"ip\":\"152.200.54.243\"," +
"\"time\":920880005.023," +
- "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," +
"\"duration\":0.122," +
"\"responsesize\":9875," +
"\"code\":200," +
@@ -102,7 +100,6 @@ public class JSONLogTestCase extends junit.framework.TestCase {
String expectedOutput =
"{\"ip\":\"152.200.54.243\"," +
"\"time\":920880005.023," +
- "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," +
"\"duration\":0.122," +
"\"responsesize\":9875," +
"\"code\":200," +
@@ -128,7 +125,6 @@ public class JSONLogTestCase extends junit.framework.TestCase {
expectedOutput =
"{\"ip\":\"152.200.54.243\"," +
"\"time\":920880005.023," +
- "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," +
"\"duration\":0.122," +
"\"responsesize\":9875," +
"\"code\":200," +
@@ -175,7 +171,6 @@ public class JSONLogTestCase extends junit.framework.TestCase {
String expectedOutput =
"{\"ip\":\"152.200.54.243\"," +
"\"time\":920880005.023," +
- "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," +
"\"duration\":0.122," +
"\"responsesize\":9875," +
"\"code\":200," +
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
index 0095fcece4f..4f365ebbab3 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
@@ -35,7 +35,42 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
this(executor, accessLog, null);
}
+ public static class Context {
+ final Executor executor;
+ final AccessLog accessLog;
+ final Metric metric;
+ @Inject
+ public Context(Executor executor, AccessLog accessLog, Metric metric) {
+ this.executor = executor;
+ this.accessLog = accessLog;
+ this.metric = metric;
+ }
+ public Context(Context other) {
+ this.executor = other.executor;
+ this.accessLog = other.accessLog;
+ this.metric = other.metric;
+ }
+ }
+ public static Context testOnlyContext() {
+ return new Context(new Executor() {
+ @Override
+ public void execute(Runnable command) {
+ command.run();
+ }
+ },
+ AccessLog.voidAccessLog(),
+ null);
+ }
+
@Inject
+ public LoggingRequestHandler(Context ctx) {
+ this(ctx.executor, ctx.accessLog, ctx.metric);
+ }
+
+ public LoggingRequestHandler(Context ctx, boolean allowAsyncResponse) {
+ this(ctx.executor, ctx.accessLog, ctx.metric, allowAsyncResponse);
+ }
+
public LoggingRequestHandler(Executor executor, AccessLog accessLog, Metric metric) {
this(executor, accessLog, metric, false);
}
diff --git a/container-dev/pom.xml b/container-dev/pom.xml
index f62bbd22690..16006452e61 100644
--- a/container-dev/pom.xml
+++ b/container-dev/pom.xml
@@ -121,6 +121,18 @@
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -189,6 +201,18 @@
<groupId>xerces</groupId>
<artifactId>xercesImpl</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java b/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java
index fa2ee8e89a9..bf696771b20 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java
@@ -29,6 +29,7 @@ import com.yahoo.jdisc.handler.RequestHandler;
import com.yahoo.jdisc.service.ClientProvider;
import com.yahoo.jdisc.service.ServerProvider;
import com.yahoo.jrt.ListenFailedException;
+import com.yahoo.log.LogLevel;
import com.yahoo.log.LogSetup;
import com.yahoo.osgi.OsgiImpl;
import com.yahoo.vespa.config.ConfigKey;
@@ -88,6 +89,7 @@ public final class ConfiguredApplication implements Application {
static {
LogSetup.initVespaLogging("Container");
+ log.log(LogLevel.INFO, "Starting container");
}
/**
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
index 033b396bc9b..c4c57f4bc47 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc.athenz;
+import javax.net.ssl.SSLContext;
+
/**
* @author mortent
*/
@@ -8,4 +10,5 @@ public interface AthenzIdentityProvider {
String getNToken() throws AthenzIdentityProviderException;
String getDomain();
String getService();
+ SSLContext getSslContext();
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
index 356780a0900..3d6b32744c6 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
@@ -8,6 +8,20 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
import com.yahoo.log.LogLevel;
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.security.KeyManagementException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.UnrecoverableKeyException;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
@@ -106,6 +120,52 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
}
@Override
+ public SSLContext getSslContext() {
+ try {
+ SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
+ sslContext.init(createKeyManagersWithServiceCertificate(),
+ createTrustManagersWithAthenzCa(),
+ null);
+ return sslContext;
+ } catch (NoSuchAlgorithmException | KeyManagementException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private KeyManager[] createKeyManagersWithServiceCertificate() {
+ try {
+ credentialsRetrievedSignal.await();
+ KeyStore keyStore = KeyStore.getInstance("JKS");
+ keyStore.load(null);
+ keyStore.setKeyEntry("instance-key",
+ credentials.get().getKeyPair().getPrivate(),
+ new char[0],
+ new Certificate[]{credentials.get().getCertificate()});
+ KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ keyManagerFactory.init(keyStore, new char[0]);
+ return keyManagerFactory.getKeyManagers();
+ } catch (KeyStoreException | NoSuchAlgorithmException | UnrecoverableKeyException | CertificateException | IOException e) {
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ throw new AthenzIdentityProviderException("Failed to register instance credentials", lastThrowable.get());
+ }
+ }
+
+ private static TrustManager[] createTrustManagersWithAthenzCa() {
+ try {
+ KeyStore trustStore = KeyStore.getInstance("JKS");
+ try (FileInputStream in = new FileInputStream("/home/y/share/ssl/certs/yahoo_certificate_bundle.jks")) {
+ trustStore.load(in, null);
+ }
+ TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ trustManagerFactory.init(trustStore);
+ return trustManagerFactory.getTrustManagers();
+ } catch (CertificateException | IOException | KeyStoreException | NoSuchAlgorithmException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
public void deconstruct() {
scheduler.shutdown(AWAIT_TERMINTATION_TIMEOUT);
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java b/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java
index a4eb2449064..b83dd6175e1 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java
@@ -37,7 +37,6 @@ public class Deconstructor implements ComponentDeconstructor {
if (component instanceof AbstractComponent) {
AbstractComponent abstractComponent = (AbstractComponent) component;
if (abstractComponent.isDeconstructable()) {
- log.info("Scheduling deconstruction of " + abstractComponent);
executor.schedule(new DestructComponentTask(abstractComponent), delay, TimeUnit.SECONDS);
}
} else if (component instanceof Provider) {
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
index fc1bbace092..1e44a8fa64d 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java
@@ -1,16 +1,16 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.prelude.fastsearch;
+import com.yahoo.container.search.LegacyEmulationConfig;
+import com.yahoo.data.access.Inspector;
+import com.yahoo.log.LogLevel;
+
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
-import com.yahoo.data.access.Inspector;
-import com.yahoo.container.search.LegacyEmulationConfig;
-
-import com.yahoo.log.LogLevel;
/**
* @author Bjørn Borud
@@ -25,7 +25,7 @@ public abstract class DocsumField {
Map<String, Constructor<? extends DocsumField>> constructors = new HashMap<>();
- void put(String typename, Class<? extends DocsumField> fieldClass)
+ void put(String typename, Class<? extends DocsumField> fieldClass)
throws NoSuchMethodException, SecurityException {
Constructor<? extends DocsumField> constructor = fieldClass.getConstructor(String.class);
constructors.put(typename, constructor);
@@ -106,7 +106,7 @@ public abstract class DocsumField {
public abstract Object decode(ByteBuffer b);
/**
- * Get the number of bytes this field occupies in the given buffer
+ * Get the number of bytes this field occupies in the given buffer
* AND SET(!) the position to the first byte after this field.
*/
public abstract int getLength(ByteBuffer b);
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
index 1524a4da426..692e93bed7e 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java
@@ -109,7 +109,7 @@ public class FastHit extends Hit {
/**
* Returns the explicitly set uri if available, returns "index:[source]/[partid]/[id]" otherwise
- *
+ *
* @return uri of hit
*/
public URI getUri() {
@@ -128,9 +128,9 @@ public class FastHit extends Hit {
}
/**
- * The uri of the index location of this hit ("index:[source]/[partid]/[id]").
+ * The uri of the index location of this hit ("index:[source]/[partid]/[id]").
* This is the uri if no other uri is assigned
- *
+ *
* @return uri to the index.
*/
public URI getIndexUri() {
@@ -215,7 +215,7 @@ public class FastHit extends Hit {
* The empty string ("") if no value is assigned in the document.
*
* <li><b>Dynamic summary string fields</b>: A Java String before JuniperSearcher and a HitField after.</li>
- *
+ *
* <li><b>Numerics</b>: The corresponding numeric Java type.<br>
* If the field has <i>no value</i> assigned in the document,
* the special numeric {@link com.yahoo.search.result.NanNumber#NaN} is returned.
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
index e0ca7fbe6e1..d8b38667224 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java
@@ -13,7 +13,7 @@ import java.util.Optional;
/**
* A tensor field. Tensors are encoded as a data field where the data (following the length)
* is encoded in a tensor binary format defined by com.yahoo.tensor.serialization.TypedBinaryFormat
- *
+ *
* @author bratseth
*/
public class TensorField extends DocsumField implements VariableLengthField {
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java
index 15a8a670a2e..8091397237d 100644
--- a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java
@@ -25,7 +25,7 @@ import static com.yahoo.prelude.querytransform.NormalizingSearcher.ACCENT_REMOVA
* Check sorting specification makes sense to the search cluster before
* passing it on to the backend.
*
- * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a>
+ * @author Steinar Knutsen
*/
@Before(PhaseNames.BACKEND)
@After(ACCENT_REMOVAL)
@@ -118,6 +118,7 @@ public class ValidateSortingSearcher extends Searcher {
for (Sorting.FieldOrder f : l) {
String name = f.getFieldName();
if ("[rank]".equals(name) || "[docid]".equals(name)) {
+ // built-in constants - ok
} else if (names.containsKey(name)) {
AttributesConfig.Attribute attrConfig = names.get(name);
if (attrConfig != null) {
@@ -166,18 +167,13 @@ public class ValidateSortingSearcher extends Searcher {
locale = "en_US";
}
- // getLogger().info("locale = " + locale + " attrConfig.sortlocale.value() = " + attrConfig.sortlocale.value() + " query.getLanguage() = " + query.getModel().getLanguage());
- // getLogger().info("locale = " + locale);
-
Sorting.UcaSorter.Strength strength = sorter.getStrength();
if (sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) {
strength = config2Strength(attrConfig.sortstrength());
}
if ((sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) || (sorter.getLocale() == null) || sorter.getLocale().isEmpty()) {
- // getLogger().info("locale = " + locale + " strength = " + strength.toString());
sorter.setLocale(locale, strength);
}
- //getLogger().info("locale = " + locale + " strength = " + strength.toString() + "decompose = " + sorter.getDecomposition());
}
} else {
return ErrorMessage.createInvalidQueryParameter("The cluster " + getClusterName() + " has attribute config for field: " + name);
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index 0ec15b95b0d..0fd529bf262 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -16,7 +16,7 @@ import java.util.Optional;
public class TensorFieldType extends FieldType {
// TODO: Require tensor type
-
+
private final Optional<TensorType> type;
/** Creates a tensor field type with optional information about the kind of tensor this will hold */
diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
index 15a0fd60511..5494d1965f8 100644
--- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
+++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java
@@ -102,7 +102,7 @@ public class SlimeSummaryTestCase {
public void testDecoding() {
Tensor tensor1 = Tensor.from("tensor(x{},y{}):{{x:foo,y:bar}:0.1}");
Tensor tensor2 = Tensor.from("tensor(x[],y[1]):{{x:0,y:0}:-0.3}");
-
+
String summary_cf = "file:src/test/java/com/yahoo/prelude/fastsearch/summary.cfg";
DocsumDefinitionSet set = createDocsumDefinitionSet(summary_cf);
byte[] docsum = makeDocsum(tensor1, tensor2);
diff --git a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
index e59c03b33c3..62eacaa0afe 100644
--- a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java
@@ -2,8 +2,6 @@
package com.yahoo.search.test;
import com.yahoo.component.chain.Chain;
-import com.yahoo.language.Language;
-import com.yahoo.language.Linguistics;
import com.yahoo.language.detect.Detection;
import com.yahoo.language.detect.Detector;
import com.yahoo.language.detect.Hint;
@@ -28,7 +26,6 @@ import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
-
import com.yahoo.yolean.Exceptions;
import org.junit.Ignore;
import org.junit.Test;
@@ -45,14 +42,14 @@ import java.util.stream.Collectors;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
@@ -69,7 +66,7 @@ public class QueryTestCase {
assertEquals("", q.properties().get("aParameter"));
assertNull(q.properties().get("notSetParameter"));
}
-
+
// TODO: YQL work in progress (jon)
@Ignore
@Test
@@ -693,7 +690,7 @@ public class QueryTestCase {
List<IndexedItem> l = QueryTree.getPositiveTerms(i);
assertEquals(3, l.size());
}
-
+
@Test
public void testHeuristicLanguageDetectionTextExtraction() {
assertDetectionText("b ", "a:b", "text:a", "text:default");
@@ -720,27 +717,27 @@ public class QueryTestCase {
q.getModel().getQueryTree(); // cause parsing
assertEquals(expectedDetectionText, mockLinguistics.detector.lastDetectionText);
}
-
+
/** A linguistics instance which records the last language detection text passed to it */
private static class MockLinguistics extends SimpleLinguistics {
final MockDetector detector = new MockDetector();
-
+
@Override
public Detector getDetector() { return detector; }
-
+
}
-
+
private static class MockDetector extends SimpleDetector {
String lastDetectionText = null;
-
+
@Override
public Detection detect(String input, Hint hint) {
lastDetectionText = input;
return super.detect(input, hint);
}
-
+
}
protected boolean contains(String lineSubstring,String[] lines) {
diff --git a/container/pom.xml b/container/pom.xml
index 3793a3508a4..d252a5eee4a 100644
--- a/container/pom.xml
+++ b/container/pom.xml
@@ -47,6 +47,18 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/controller-api/pom.xml b/controller-api/pom.xml
index 5ef130a22ba..543ab24999d 100644
--- a/controller-api/pom.xml
+++ b/controller-api/pom.xml
@@ -18,24 +18,9 @@
<dependencies>
<!-- provided -->
-
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>component</artifactId>
- <scope>provided</scope>
- <version>${project.version}</version>
- </dependency>
-
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>annotations</artifactId>
- <scope>provided</scope>
- <version>${project.version}</version>
- </dependency>
-
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>vespajlib</artifactId>
+ <artifactId>container-dev</artifactId>
<scope>provided</scope>
<version>${project.version}</version>
</dependency>
@@ -54,56 +39,6 @@
<version>${project.version}</version>
</dependency>
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-annotations</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>com.fasterxml.jackson.datatype</groupId>
- <artifactId>jackson-datatype-jdk8</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>org.glassfish.jersey.media</groupId>
- <artifactId>jersey-media-multipart</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>javax.servlet</groupId>
- <artifactId>javax.servlet-api</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>javax.ws.rs</groupId>
- <artifactId>javax.ws.rs-api</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>org.glassfish.jersey.core</groupId>
- <artifactId>jersey-server</artifactId>
- <version>${jersey2.version}</version>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>com.google.inject</groupId>
- <artifactId>guice</artifactId>
- <classifier>no_aop</classifier>
- <scope>provided</scope>
- </dependency>
-
<!-- compile -->
<dependency>
@@ -128,6 +63,19 @@
<scope>test</scope>
</dependency>
+ <!-- Required for AthenzIdentityVerifierTest -->
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcpkix-jdk15on</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+
</dependencies>
<build>
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java
new file mode 100644
index 00000000000..bfaa6c2acda
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java
@@ -0,0 +1,41 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.athenz;
+
+import javax.net.ssl.HostnameVerifier;
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
+import java.security.cert.X509Certificate;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A {@link HostnameVerifier} that validates Athenz x509 certificates using the identity in the Common Name attribute.
+ *
+ * @author bjorncs
+ */
+// TODO Move to dedicated Athenz bundle
+public class AthenzIdentityVerifier implements HostnameVerifier {
+
+ private static final Logger log = Logger.getLogger(AthenzIdentityVerifier.class.getName());
+
+ private final Set<AthenzIdentity> allowedIdentities;
+
+ public AthenzIdentityVerifier(Set<AthenzIdentity> allowedIdentities) {
+ this.allowedIdentities = allowedIdentities;
+ }
+
+ @Override
+ public boolean verify(String hostname, SSLSession session) {
+ try {
+ X509Certificate cert = (X509Certificate) session.getPeerCertificates()[0];
+ AthenzIdentity certificateIdentity = AthenzUtils.createAthenzIdentity(cert);
+ return allowedIdentities.contains(certificateIdentity);
+ } catch (SSLPeerUnverifiedException e) {
+ log.log(Level.WARNING, "Unverified client: " + hostname);
+ return false;
+ }
+ }
+
+}
+
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java
index 8279edcd8e6..b31cb4a26bb 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java
@@ -5,6 +5,7 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain;
import java.security.Principal;
import java.util.Objects;
+import java.util.Optional;
/**
* @author bjorncs
@@ -14,6 +15,10 @@ public class AthenzPrincipal implements Principal {
private final AthenzIdentity athenzIdentity;
private final NToken nToken;
+ public AthenzPrincipal(AthenzIdentity athenzIdentity) {
+ this(athenzIdentity, null);
+ }
+
public AthenzPrincipal(AthenzIdentity athenzIdentity,
NToken nToken) {
this.athenzIdentity = athenzIdentity;
@@ -33,8 +38,8 @@ public class AthenzPrincipal implements Principal {
return athenzIdentity.getDomain();
}
- public NToken getNToken() {
- return nToken;
+ public Optional<NToken> getNToken() {
+ return Optional.ofNullable(nToken);
}
@Override
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java
index 0ed5d86dd7e..04ec0b61614 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java
@@ -4,6 +4,10 @@ package com.yahoo.vespa.hosted.controller.api.integration.athenz;
import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserId;
+import javax.naming.NamingException;
+import javax.naming.ldap.LdapName;
+import java.security.cert.X509Certificate;
+
/**
* @author bjorncs
*/
@@ -23,4 +27,40 @@ public class AthenzUtils {
}
}
+ public static AthenzIdentity createAthenzIdentity(String fullName) {
+ int domainIdentityNameSeparatorIndex = fullName.lastIndexOf('.');
+ if (domainIdentityNameSeparatorIndex == -1
+ || domainIdentityNameSeparatorIndex == 0
+ || domainIdentityNameSeparatorIndex == fullName.length() - 1) {
+ throw new IllegalArgumentException("Invalid Athenz identity: " + fullName);
+ }
+ AthenzDomain domain = new AthenzDomain(fullName.substring(0, domainIdentityNameSeparatorIndex));
+ String identityName = fullName.substring(domainIdentityNameSeparatorIndex + 1, fullName.length());
+ return createAthenzIdentity(domain, identityName);
+ }
+
+ public static AthenzIdentity createAthenzIdentity(X509Certificate certificate) {
+ String commonName = getCommonName(certificate);
+ if (isAthenzRoleIdentity(commonName)) {
+ throw new IllegalArgumentException("Athenz role certificate not supported");
+ }
+ return createAthenzIdentity(commonName);
+ }
+
+ private static boolean isAthenzRoleIdentity(String commonName) {
+ return commonName.contains(":role.");
+ }
+
+ private static String getCommonName(X509Certificate certificate) {
+ try {
+ String subjectPrincipal = certificate.getSubjectX500Principal().getName();
+ return new LdapName(subjectPrincipal).getRdns().stream()
+ .filter(rdn -> rdn.getType().equalsIgnoreCase("cn"))
+ .map(rdn -> rdn.getValue().toString())
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find CN in certificate: " + subjectPrincipal));
+ } catch (NamingException e) {
+ throw new IllegalArgumentException("Invalid CN: " + e, e);
+ }
+ }
}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java
index 1df1746b02e..967af1c553f 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java
@@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.controller.api.integration.athenz;
/**
* @author bjorncs
*/
-public class InvalidTokenException extends Exception {
+public class InvalidTokenException extends RuntimeException {
public InvalidTokenException(String message) {
super(message);
}
diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java
new file mode 100644
index 00000000000..88da28fb273
--- /dev/null
+++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java
@@ -0,0 +1,82 @@
+package com.yahoo.vespa.hosted.controller.api.integration.athenz;
+
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.BasicConstraints;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.cert.CertIOException;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.junit.Test;
+
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Date;
+
+import static java.util.Collections.singleton;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author bjorncs
+ */
+public class AthenzIdentityVerifierTest {
+
+ @Test
+ public void verifies_certificate_with_athenz_service_as_common_name() throws Exception {
+ AthenzIdentity trustedIdentity = new AthenzService("mydomain", "alice");
+ AthenzIdentity unknownIdentity = new AthenzService("mydomain", "mallory");
+ KeyPair keyPair = createKeyPair();
+ AthenzIdentityVerifier verifier = new AthenzIdentityVerifier(singleton(trustedIdentity));
+ assertTrue(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, trustedIdentity))));
+ assertFalse(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, unknownIdentity))));
+ }
+
+ private static KeyPair createKeyPair() throws NoSuchAlgorithmException {
+ KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
+ keyGen.initialize(512);
+ return keyGen.generateKeyPair();
+ }
+
+ private static X509Certificate createSelfSignedCertificate(KeyPair keyPair, AthenzIdentity identity)
+ throws OperatorCreationException, CertIOException, CertificateException {
+ ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate());
+ X500Name x500Name = new X500Name("CN="+ identity.getFullName());
+ Instant now = Instant.now();
+ Date notBefore = Date.from(now);
+ Date notAfter = Date.from(now.plus(Duration.ofDays(30)));
+
+ X509v3CertificateBuilder certificateBuilder =
+ new JcaX509v3CertificateBuilder(
+ x500Name, BigInteger.valueOf(now.toEpochMilli()), notBefore, notAfter, x500Name, keyPair.getPublic()
+ )
+ .addExtension(Extension.basicConstraints, true, new BasicConstraints(true));
+
+ return new JcaX509CertificateConverter()
+ .setProvider(new BouncyCastleProvider())
+ .getCertificate(certificateBuilder.build(contentSigner));
+
+ }
+
+ private static SSLSession createSslSessionMock(X509Certificate certificate) throws SSLPeerUnverifiedException {
+ SSLSession sslSession = mock(SSLSession.class);
+ when(sslSession.getPeerCertificates()).thenReturn(new Certificate[]{certificate});
+ return sslSession;
+ }
+
+} \ No newline at end of file
diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java
new file mode 100644
index 00000000000..f2db74a4c3d
--- /dev/null
+++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java
@@ -0,0 +1,21 @@
+package com.yahoo.vespa.hosted.controller.api.integration.athenz;
+
+import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bjorncs
+ */
+public class AthenzUtilsTest {
+
+ @Test
+ public void athenz_identity_is_parsed_from_dot_separated_string() {
+ AthenzIdentity expectedIdentity = new AthenzService(new AthenzDomain("my.subdomain"), "myservicename");
+ String fullName = expectedIdentity.getFullName();
+ AthenzIdentity actualIdentity = AthenzUtils.createAthenzIdentity(fullName);
+ assertEquals(expectedIdentity, actualIdentity);
+ }
+
+} \ No newline at end of file
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
index 1b2ad9f938a..fb675862320 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java
@@ -24,6 +24,7 @@ import java.util.Objects;
* @author smorgrav
*/
public class ClusterCost {
+
private final double tco;
private final double waste;
private final ClusterInfo clusterInfo;
@@ -32,8 +33,8 @@ public class ClusterCost {
private final ClusterUtilization resultUtilization;
/**
- * @param clusterInfo Value object with cluster info e.g. the TCO for the hardware used
- * @param systemUtilization Utilization of system resources (as ratios)
+ * @param clusterInfo value object with cluster info e.g. the TCO for the hardware used
+ * @param systemUtilization utilization of system resources (as ratios)
*/
public ClusterCost(ClusterInfo clusterInfo,
ClusterUtilization systemUtilization) {
@@ -79,10 +80,10 @@ public class ClusterCost {
}
static ClusterUtilization calculateResultUtilization(ClusterUtilization system, ClusterUtilization target) {
- double cpu = ratio(system.getCpu(),target.getCpu());
- double mem = ratio(system.getMemory(),target.getMemory());
- double disk = ratio(system.getDisk(),target.getDisk());
- double diskbusy = ratio(system.getDiskBusy(),target.getDiskBusy());
+ double cpu = ratio(system.getCpu(), target.getCpu());
+ double mem = ratio(system.getMemory(), target.getMemory());
+ double disk = ratio(system.getDisk(), target.getDisk());
+ double diskbusy = ratio(system.getDiskBusy(), target.getDiskBusy());
return new ClusterUtilization(mem, cpu, disk, diskbusy);
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
index 585690793bb..371e1c41e32 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java
@@ -44,17 +44,17 @@ public class DeploymentCost {
return clusters;
}
- /** @return Total cost of ownership for the deployment (sum of all clusters) */
+ /** Returns the total monthly cost of ownership for the deployment (sum of all clusters) */
public double getTco() {
return tco;
}
- /** @return The utilization of clusters that wastes most money in this deployment */
+ /** Returns the utilization of clusters that wastes most money in this deployment */
public double getUtilization() {
return utilization;
}
- /** @return The amount of dollars spent and not utilized */
+ /** Returns the amount of dollars spent and not utilized */
public double getWaste() {
return waste;
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java
index ceb04d88026..a7940076277 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java
@@ -11,14 +11,14 @@ import java.util.Optional;
/**
* The last known build status of a particular deployment job for a particular application.
* This is immutable.
- *
+ *
* @author bratseth
* @author mpolden
*/
public class JobStatus {
-
+
private final DeploymentJobs.JobType type;
-
+
private final Optional<JobRun> lastTriggered;
private final Optional<JobRun> lastCompleted;
private final Optional<JobRun> firstFailing;
@@ -42,7 +42,7 @@ public class JobStatus {
this.type = type;
this.jobError = jobError;
-
+
// Never say we triggered component because we don't:
this.lastTriggered = type == DeploymentJobs.JobType.component ? Optional.empty() : lastTriggered;
this.lastCompleted = lastCompleted;
@@ -52,7 +52,7 @@ public class JobStatus {
/** Returns an empty job status */
public static JobStatus initial(DeploymentJobs.JobType type) {
- return new JobStatus(type, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
+ return new JobStatus(type, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}
public JobStatus withTriggering(Version version, Optional<ApplicationRevision> revision,
@@ -89,13 +89,13 @@ public class JobStatus {
Optional<JobRun> firstFailing = this.firstFailing;
if (jobError.isPresent() && ! this.firstFailing.isPresent())
firstFailing = Optional.of(thisCompletion);
-
+
Optional<JobRun> lastSuccess = this.lastSuccess;
if ( ! jobError.isPresent()) {
lastSuccess = Optional.of(thisCompletion);
firstFailing = Optional.empty();
}
-
+
return new JobStatus(type, jobError, lastTriggered, Optional.of(thisCompletion), firstFailing, lastSuccess);
}
@@ -105,7 +105,7 @@ public class JobStatus {
public boolean isSuccess() {
return lastCompleted().isPresent() && ! jobError.isPresent();
}
-
+
/** Returns true if last triggered is newer than last completed and was started after timeoutLimit */
public boolean isRunning(Instant timeoutLimit) {
if ( ! lastTriggered.isPresent()) return false;
@@ -114,6 +114,11 @@ public class JobStatus {
return ! lastTriggered.get().at().isBefore(lastCompleted.get().at());
}
+ /** Returns true if this is running and has been so since before the given limit */
+ public boolean isHanging(Instant timeoutLimit) {
+ return isRunning(Instant.MIN) && lastTriggered.get().at().isBefore(timeoutLimit.plusMillis(1));
+ }
+
/** The error of the last completion, or empty if the last run succeeded */
public Optional<DeploymentJobs.JobError> jobError() { return jobError; }
@@ -140,10 +145,10 @@ public class JobStatus {
", first failing: " + firstFailing.map(JobRun::toString).orElse("(not failing)") +
", lastSuccess: " + lastSuccess.map(JobRun::toString).orElse("(never)") + "]";
}
-
+
@Override
public int hashCode() { return Objects.hash(type, jobError, lastTriggered, lastCompleted, firstFailing, lastSuccess); }
-
+
@Override
public boolean equals(Object o) {
if (o == this) return true;
@@ -159,15 +164,15 @@ public class JobStatus {
/** Information about a particular triggering or completion of a run of a job. This is immutable. */
public static class JobRun {
-
+
private final long id;
private final Version version;
private final Optional<ApplicationRevision> revision;
private final boolean upgrade;
private final String reason;
private final Instant at;
-
- public JobRun(long id, Version version, Optional<ApplicationRevision> revision,
+
+ public JobRun(long id, Version version, Optional<ApplicationRevision> revision,
boolean upgrade, String reason, Instant at) {
Objects.requireNonNull(version, "version cannot be null");
Objects.requireNonNull(revision, "revision cannot be null");
@@ -188,16 +193,16 @@ public class JobStatus {
// TODO: Fix how this is set, and add an applicationChange() method as well, in the same vein.
/** Returns whether this job run was a Vespa upgrade */
public boolean upgrade() { return upgrade; }
-
+
/** Returns the Vespa version used on this run */
public Version version() { return version; }
-
+
/** Returns the application revision used for this run, or empty when not known */
public Optional<ApplicationRevision> revision() { return revision; }
-
+
/** Returns a human-readable reason for this particular job run */
public String reason() { return reason; }
-
+
/** Returns the time if this triggering or completion */
public Instant at() { return at; }
@@ -218,7 +223,7 @@ public class JobStatus {
public int hashCode() {
return Objects.hash(version, revision, upgrade, at);
}
-
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -234,7 +239,7 @@ public class JobStatus {
@Override
public String toString() { return "job run " + id + " of version " + (upgrade() ? "upgrade " : "") + version + " "
+ revision + " at " + at; }
-
+
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java
index 328461355db..7aaaad534db 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java
@@ -7,17 +7,24 @@ import com.yahoo.jdisc.handler.ResponseHandler;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
import com.yahoo.jdisc.http.filter.SecurityRequestFilter;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal;
-import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException;
+import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUtils;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsKeystore;
import com.yahoo.vespa.hosted.controller.athenz.config.AthenzConfig;
+import java.security.cert.X509Certificate;
+import java.util.Optional;
import java.util.concurrent.Executor;
import static com.yahoo.vespa.hosted.controller.athenz.filter.SecurityFilterUtils.sendErrorResponse;
/**
- * Performs authentication by validating the principal token (NToken) header.
+ * Authenticates Athenz principal, either through:
+ * 1. TLS client authentication (based on Athenz x509 identity certficiate).
+ * 2. The principal token (NToken) header.
+ * The TLS authentication is based on the following assumptions:
+ * - The underlying connector is configured with 'clientAuth' set to either WANT_AUTH or NEED_AUTH.
+ * - The trust store is configured with the Athenz CA certificates only.
*
* @author bjorncs
*/
@@ -43,18 +50,45 @@ public class AthenzPrincipalFilter implements SecurityRequestFilter {
@Override
public void filter(DiscFilterRequest request, ResponseHandler responseHandler) {
- String rawToken = request.getHeader(principalTokenHeader);
- if (rawToken == null || rawToken.isEmpty()) {
- sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, "NToken is missing");
- return;
- }
try {
- AthenzPrincipal principal = validator.validate(new NToken(rawToken));
+ Optional<AthenzPrincipal> certificatePrincipal = getClientCertificate(request)
+ .map(AthenzUtils::createAthenzIdentity)
+ .map(AthenzPrincipal::new);
+ Optional<AthenzPrincipal> nTokenPrincipal = getPrincipalToken(request, principalTokenHeader)
+ .map(validator::validate);
+
+ if (!certificatePrincipal.isPresent() && !nTokenPrincipal.isPresent()) {
+ String errorMessage = "Unable to authenticate Athenz identity. " +
+ "Both client certificate missing and principal token header are missing.";
+ sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, errorMessage);
+ return;
+ }
+ if (certificatePrincipal.isPresent() && nTokenPrincipal.isPresent()
+ && !certificatePrincipal.get().getIdentity().equals(nTokenPrincipal.get().getIdentity())) {
+ String errorMessage = String.format(
+ "Identity in principal token does not match x509 CN: token-identity=%s, cert-identity=%s",
+ nTokenPrincipal.get().getIdentity().getFullName(),
+ certificatePrincipal.get().getIdentity().getFullName());
+ sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, errorMessage);
+ return;
+ }
+ AthenzPrincipal principal = nTokenPrincipal.orElseGet(certificatePrincipal::get);
request.setUserPrincipal(principal);
request.setRemoteUser(principal.getName());
- } catch (InvalidTokenException e) {
+ } catch (Exception e) {
sendErrorResponse(responseHandler,Response.Status.UNAUTHORIZED, e.getMessage());
}
}
+ private static Optional<X509Certificate> getClientCertificate(DiscFilterRequest request) {
+ return Optional.ofNullable((X509Certificate[]) request.getAttribute("jdisc.request.X509Certificate"))
+ .map(chain -> chain[0]);
+ }
+
+ private static Optional<NToken> getPrincipalToken(DiscFilterRequest request, String principalTokenHeaderName) {
+ return Optional.ofNullable(request.getHeader(principalTokenHeaderName))
+ .filter(token -> !token.isEmpty())
+ .map(NToken::new);
+ }
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
index 487cbc02acc..6a9db3ae917 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
@@ -32,9 +32,9 @@ import java.util.logging.Logger;
/**
* Responsible for scheduling deployment jobs in a build system and keeping
* Application.deploying() in sync with what is scheduled.
- *
+ *
* This class is multithread safe.
- *
+ *
* @author bratseth
* @author mpolden
*/
@@ -60,7 +60,7 @@ public class DeploymentTrigger {
this.order = new DeploymentOrder(controller);
this.jobTimeout = controller.system().equals(SystemName.main) ? Duration.ofHours(12) : Duration.ofHours(1);
}
-
+
/** Returns the time in the past before which jobs are at this moment considered unresponsive */
public Instant jobTimeoutLimit() { return clock.instant().minus(jobTimeout); }
@@ -70,10 +70,10 @@ public class DeploymentTrigger {
//--- Start of methods which triggers deployment jobs -------------------------
- /**
+ /**
* Called each time a job completes (successfully or not) to cause triggering of one or more follow-up jobs
* (which may possibly the same job once over).
- *
+ *
* @param report information about the job that just completed
*/
public void triggerFromCompletion(JobReport report) {
@@ -143,10 +143,11 @@ public class DeploymentTrigger {
JobStatus systemTestStatus = application.deploymentJobs().jobStatus().get(JobType.systemTest);
if (application.deploying().get() instanceof Change.VersionChange) {
Version target = ((Change.VersionChange) application.deploying().get()).version();
- if (systemTestStatus == null
+ if (systemTestStatus == null
|| ! systemTestStatus.lastTriggered().isPresent()
|| ! systemTestStatus.isSuccess()
- || ! systemTestStatus.lastTriggered().get().version().equals(target)) {
+ || ! systemTestStatus.lastTriggered().get().version().equals(target)
+ || systemTestStatus.isHanging(jobTimeoutLimit())) {
application = trigger(JobType.systemTest, application, false, "Upgrade to " + target);
controller.applications().store(application);
}
@@ -170,7 +171,7 @@ public class DeploymentTrigger {
List<JobType> nextToTrigger = new ArrayList<>();
for (JobType nextJobType : order.nextAfter(jobType, application)) {
JobStatus nextStatus = application.deploymentJobs().jobStatus().get(nextJobType);
- if (changesAvailable(application, jobStatus, nextStatus))
+ if (changesAvailable(application, jobStatus, nextStatus) || nextStatus.isHanging(jobTimeoutLimit()))
nextToTrigger.add(nextJobType);
}
// Trigger them in parallel
@@ -209,10 +210,10 @@ public class DeploymentTrigger {
return true;
return false;
}
-
+
/**
* Triggers a change of this application
- *
+ *
* @param applicationId the application to trigger
* @throws IllegalArgumentException if this application already have an ongoing change
*/
@@ -267,7 +268,7 @@ public class DeploymentTrigger {
}
/**
- * Trigger a job for an application
+ * Trigger a job for an application
*
* @param jobType the type of the job to trigger, or null to trigger nothing
* @param application the application to trigger the job for
@@ -289,7 +290,7 @@ public class DeploymentTrigger {
/**
* Trigger a job for an application, if allowed
- *
+ *
* @param jobType the type of the job to trigger, or null to trigger nothing
* @param application the application to trigger the job for
* @param first whether to trigger the job before other jobs
@@ -323,7 +324,7 @@ public class DeploymentTrigger {
/** Returns true if the given proposed job triggering should be effected */
private boolean allowedTriggering(JobType jobType, LockedApplication application) {
- // Note: We could make a more fine-grained and more correct determination about whether to block
+ // Note: We could make a more fine-grained and more correct determination about whether to block
// by instead basing the decision on what is currently deployed in the zone. However,
// this leads to some additional corner cases, and the possibility of blocking an application
// fix to a version upgrade, so not doing it now
@@ -341,7 +342,7 @@ public class DeploymentTrigger {
return true;
}
-
+
private boolean isRunningProductionJob(Application application) {
return JobList.from(application)
.production()
@@ -364,7 +365,7 @@ public class DeploymentTrigger {
if (existingDeployment == null) return false;
return existingDeployment.version().isAfter(version);
}
-
+
private boolean acceptNewRevisionNow(LockedApplication application) {
if ( ! application.deploying().isPresent()) return true;
@@ -377,5 +378,5 @@ public class DeploymentTrigger {
// Otherwise, the application is currently upgrading, without failures, and we should wait with the revision.
return false;
}
-
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
index b7080a763f0..77ce49eaf47 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
@@ -9,14 +9,13 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain;
import com.yahoo.vespa.hosted.controller.api.identifiers.TenantId;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserGroup;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserId;
-import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzIdentity;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUser;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken;
+import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService;
import com.yahoo.vespa.hosted.controller.common.ContextAttributes;
-import com.yahoo.vespa.hosted.controller.restapi.filter.NTokenRequestFilter;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.HttpMethod;
@@ -78,8 +77,7 @@ public class Authorizer {
}
public Optional<NToken> getNToken(HttpRequest request) {
- String nTokenHeader = (String)request.getJDiscRequest().context().get(NTokenRequestFilter.NTOKEN_HEADER);
- return Optional.ofNullable(nTokenHeader).map(NToken::new);
+ return getPrincipalIfAny(request).flatMap(AthenzPrincipal::getNToken);
}
public boolean isSuperUser(HttpRequest request) {
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
index ffb78b7342a..c887fbfc1a8 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
@@ -7,10 +7,19 @@ import com.yahoo.jdisc.handler.ReadableContentChannel;
import com.yahoo.jdisc.handler.ResponseHandler;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserId;
+import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzIdentity;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUser;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
import org.junit.Before;
import org.junit.Test;
@@ -18,6 +27,15 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Date;
import java.util.Objects;
import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED;
@@ -37,21 +55,21 @@ public class AthenzPrincipalFilterTest {
private static final NToken NTOKEN = new NToken("dummy");
private static final String ATHENZ_PRINCIPAL_HEADER = "Athenz-Principal-Auth";
+ private static final AthenzIdentity IDENTITY = AthenzUser.fromUserId(new UserId("bob"));
+ private static final X509Certificate CERTIFICATE = createSelfSignedCertificate(IDENTITY);
private NTokenValidator validator;
- private AthenzPrincipal principal;
@Before
public void before() {
validator = mock(NTokenValidator.class);
- principal = new AthenzPrincipal(AthenzUser.fromUserId(new UserId("bob")), NTOKEN);
}
@Test
- public void valid_ntoken_is_accepted() throws Exception {
+ public void valid_ntoken_is_accepted() {
DiscFilterRequest request = mock(DiscFilterRequest.class);
+ AthenzPrincipal principal = new AthenzPrincipal(IDENTITY, NTOKEN);
when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
-
when(validator.validate(NTOKEN)).thenReturn(principal);
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
@@ -61,7 +79,7 @@ public class AthenzPrincipalFilterTest {
}
@Test
- public void missing_token_is_unauthorized() throws Exception {
+ public void missing_token_and_certificate_is_unauthorized() {
DiscFilterRequest request = mock(DiscFilterRequest.class);
when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null);
@@ -70,26 +88,76 @@ public class AthenzPrincipalFilterTest {
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
filter.filter(request, responseHandler);
- assertThat(responseHandler.response, notNullValue());
- assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED));
- assertThat(responseHandler.getResponseContent(), containsString("NToken is missing"));
+ assertUnauthorized(responseHandler, "Unable to authenticate Athenz identity");
+ }
+
+ @Test
+ public void invalid_token_is_unauthorized() {
+ DiscFilterRequest request = mock(DiscFilterRequest.class);
+ String errorMessage = "Invalid token";
+ when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
+ when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException(errorMessage));
+
+ ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+
+ AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
+ filter.filter(request, responseHandler);
+
+ assertUnauthorized(responseHandler, errorMessage);
+ }
+
+ @Test
+ public void certificate_is_accepted() {
+ DiscFilterRequest request = mock(DiscFilterRequest.class);
+ when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null);
+ when(request.getAttribute("jdisc.request.X509Certificate")).thenReturn(new X509Certificate[]{CERTIFICATE});
+
+ ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+
+ AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
+ filter.filter(request, responseHandler);
+
+ AthenzPrincipal expectedPrincipal = new AthenzPrincipal(IDENTITY);
+ verify(request).setUserPrincipal(expectedPrincipal);
}
@Test
- public void invalid_token_is_unauthorized() throws Exception {
+ public void both_ntoken_and_certificate_is_accepted() {
DiscFilterRequest request = mock(DiscFilterRequest.class);
+ AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN);
when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
+ when(request.getAttribute("jdisc.request.X509Certificate")).thenReturn(new X509Certificate[]{CERTIFICATE});
+ when(validator.validate(NTOKEN)).thenReturn(principalWithToken);
+
+ ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+
+ AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
+ filter.filter(request, responseHandler);
- when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException("Invalid token"));
+ verify(request).setUserPrincipal(principalWithToken);
+ }
+
+ @Test
+ public void conflicting_ntoken_and_certificate_is_unauthorized() {
+ DiscFilterRequest request = mock(DiscFilterRequest.class);
+ AthenzUser conflictingIdentity = AthenzUser.fromUserId(new UserId("mallory"));
+ when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
+ when(request.getAttribute("jdisc.request.X509Certificate"))
+ .thenReturn(new X509Certificate[]{createSelfSignedCertificate(conflictingIdentity)});
+ when(validator.validate(NTOKEN)).thenReturn(new AthenzPrincipal(IDENTITY));
ResponseHandlerMock responseHandler = new ResponseHandlerMock();
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER);
filter.filter(request, responseHandler);
+ assertUnauthorized(responseHandler, "Identity in principal token does not match x509 CN");
+ }
+
+ private static void assertUnauthorized(ResponseHandlerMock responseHandler, String expectedMessageSubstring) {
assertThat(responseHandler.response, notNullValue());
assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED));
- assertThat(responseHandler.getResponseContent(), containsString("Invalid token"));
+ assertThat(responseHandler.getResponseContent(), containsString(expectedMessageSubstring));
}
private static class ResponseHandlerMock implements ResponseHandler {
@@ -114,4 +182,24 @@ public class AthenzPrincipalFilterTest {
}
+ // TODO Move this to separate athenz module/bundle
+ private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) {
+ try {
+ KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
+ keyGen.initialize(512);
+ KeyPair keyPair = keyGen.genKeyPair();
+ ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate());
+ X500Name x500Name = new X500Name("CN="+ identity.getFullName());
+ X509v3CertificateBuilder certificateBuilder =
+ new JcaX509v3CertificateBuilder(
+ x500Name, BigInteger.ONE, new Date(), Date.from(Instant.now().plus(Duration.ofDays(30))),
+ x500Name, keyPair.getPublic());
+ return new JcaX509CertificateConverter()
+ .setProvider(new BouncyCastleProvider())
+ .getCertificate(certificateBuilder.build(contentSigner));
+ } catch (CertificateException | NoSuchAlgorithmException | OperatorCreationException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
}
diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java
index c8a04866aa9..abdbf394591 100644
--- a/document/src/main/java/com/yahoo/document/DataType.java
+++ b/document/src/main/java/com/yahoo/document/DataType.java
@@ -51,7 +51,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory());
public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory());
public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory());
- public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately
+ public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately
// ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor
// Tags are converted to weightedset<string> when reading the search definition TODO: Remove it
@@ -99,7 +99,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
/**
* Creates a field value by reflection
- *
+ *
* @param arg the value of the newly created field value
* @return a fully constructed value
*/
@@ -201,7 +201,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
public static TensorDataType getTensor(TensorType type) {
return new TensorDataType(type);
}
-
+
public String getName() {
return name;
}
@@ -267,7 +267,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
*/
public FieldPath buildFieldPath(String fieldPathString) {
if (fieldPathString.length() > 0) {
- throw new IllegalArgumentException("Datatype " + toString() +
+ throw new IllegalArgumentException("Datatype " + toString() +
" does not support further recursive structure: " + fieldPathString);
}
return new FieldPath();
diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
index 8c9318199d8..5fad35a2287 100644
--- a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
+++ b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java
@@ -38,7 +38,7 @@ public class DocumentTypeManager {
// *Configured data types* (not built-in/primitive) indexed by their id
//
// *Primitive* data types are always available and have a single id.
- //
+ //
// *Built-in dynamic* types: The tensor type.
// Any tensor type has the same id and is always available just like primitive types.
// However, unlike primitive types, each tensor type is a separate DataType instance
@@ -112,7 +112,7 @@ public class DocumentTypeManager {
public DataType getDataType(String name) {
if (name.startsWith("tensor(")) // built-in dynamic
return new TensorDataType(TensorType.fromSpec(name));
-
+
List<DataType> foundTypes = new ArrayList<>();
for (DataType type : dataTypes.values()) {
if (type.getName().equalsIgnoreCase(name)) {
@@ -141,10 +141,10 @@ public class DocumentTypeManager {
}
public DataType getDataType(int code) { return getDataType(code, ""); }
-
+
/**
* Return a data type instance
- *
+ *
* @param code the code of the data type to return, which must be either built in or present in this manager
* @param detailedType detailed type information, or the empty string if none
* @return the appropriate DataType instance
@@ -183,7 +183,7 @@ public class DocumentTypeManager {
/**
* Register a single datatype. Re-registering an existing, but equal, datatype is ok.
- *
+ *
* @param type The datatype to register
*/
void registerSingleType(DataType type) {
@@ -280,7 +280,7 @@ public class DocumentTypeManager {
/**
* Returns a read only view of the registered data types
- *
+ *
* @return collection of types
*/
public Collection<DataType> getDataTypes() {
diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java
index aefdc030a12..50e9cf0f60f 100644
--- a/document/src/main/java/com/yahoo/document/TensorDataType.java
+++ b/document/src/main/java/com/yahoo/document/TensorDataType.java
@@ -8,13 +8,13 @@ import com.yahoo.vespa.objects.Ids;
/**
* A DataType containing a tensor type
- *
+ *
* @author bratseth
*/
public class TensorDataType extends DataType {
private final TensorType tensorType;
-
+
// The global class identifier shared with C++.
public static int classId = registerClass(Ids.document + 59, TensorDataType.class);
@@ -47,5 +47,5 @@ public class TensorDataType extends DataType {
/** Returns the type of the tensor this field can hold */
public TensorType getTensorType() { return tensorType; }
-
+
}
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index ae8d5cf596a..1808396986e 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -19,7 +19,7 @@ import java.util.Optional;
public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
-
+
private final TensorDataType dataType;
/** Create an empty tensor field value */
@@ -66,7 +66,7 @@ public class TensorFieldValue extends FieldValue {
o.getClass().getName() + "'.");
}
}
-
+
public void assignTensor(Optional<Tensor> tensor) {
if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index f37fa5ea675..29ba244a9f1 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -146,9 +146,9 @@ public class JsonReaderTestCase {
}
{
DocumentType x = new DocumentType("testtensor");
- x.addField(new Field("mappedtensorfield",
+ x.addField(new Field("mappedtensorfield",
new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build())));
- x.addField(new Field("indexedtensorfield",
+ x.addField(new Field("indexedtensorfield",
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
types.registerDocumentType(x);
}
@@ -1280,8 +1280,8 @@ public class JsonReaderTestCase {
return (DocumentPut) reader.next();
}
- private DocumentPut createPutWithMappedTensor(String inputTensor) {
- return createPutWithTensor(inputTensor, "mappedtensorfield");
+ private DocumentPut createPutWithMappedTensor(String inputTensor) {
+ return createPutWithTensor(inputTensor, "mappedtensorfield");
}
private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) {
InputStream rawDoc = new ByteArrayInputStream(
diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
index 7104c1686f8..5c65b11a0c4 100644
--- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
+++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java
@@ -24,7 +24,7 @@ public class TensorFieldValueSerializationTestCase {
private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build();
private final static String TENSOR_FIELD = "my_tensor";
private final static String TENSOR_FILES = "src/test/resources/tensor/";
- private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(),
+ private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(),
"id:test:my_type::foo");
private static DocumentType createDocType() {
diff --git a/document/src/vespa/document/bucket/bucketspace.h b/document/src/vespa/document/bucket/bucketspace.h
index 1198b173a4b..99b510f7aff 100644
--- a/document/src/vespa/document/bucket/bucketspace.h
+++ b/document/src/vespa/document/bucket/bucketspace.h
@@ -16,15 +16,16 @@ class BucketSpace {
public:
using Type = uint64_t;
- BucketSpace(const BucketSpace&) noexcept = default;
- BucketSpace& operator=(const BucketSpace&) noexcept = default;
- explicit BucketSpace(Type id) noexcept : _id(id) {}
+ constexpr BucketSpace(const BucketSpace&) noexcept = default;
+ constexpr BucketSpace& operator=(const BucketSpace&) noexcept = default;
+ constexpr explicit BucketSpace(Type id) noexcept : _id(id) {}
- bool operator <(const BucketSpace& bucket) const noexcept { return _id < bucket._id; }
- bool operator==(const BucketSpace& bucket) const noexcept { return _id == bucket._id; }
- bool operator!=(const BucketSpace& bucket) const noexcept { return _id != bucket._id; }
+ constexpr bool operator <(const BucketSpace& bucket) const noexcept { return _id < bucket._id; }
+ constexpr bool operator==(const BucketSpace& bucket) const noexcept { return _id == bucket._id; }
+ constexpr bool operator!=(const BucketSpace& bucket) const noexcept { return _id != bucket._id; }
- Type getId() const noexcept { return _id; }
+ constexpr Type getId() const noexcept { return _id; }
+ constexpr bool valid() const noexcept { return (_id != 0); }
vespalib::string toString() const;
struct hash {
@@ -36,7 +37,8 @@ public:
/*
* Temporary placeholder value while wiring in use of BucketSpace in APIs.
*/
- static BucketSpace placeHolder() { return BucketSpace(0); }
+ static constexpr BucketSpace placeHolder() noexcept { return BucketSpace(1); }
+ static constexpr BucketSpace invalid() noexcept { return BucketSpace(0); }
private:
Type _id;
};
diff --git a/document/src/vespa/document/select/CMakeLists.txt b/document/src/vespa/document/select/CMakeLists.txt
index 6dadd35e98a..bc73498622d 100644
--- a/document/src/vespa/document/select/CMakeLists.txt
+++ b/document/src/vespa/document/select/CMakeLists.txt
@@ -1,10 +1,14 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-find_package(BISON REQUIRED)
-find_package(FLEX REQUIRED)
+find_package(BISON REQUIRED 3.0)
+find_package(FLEX REQUIRED 2.5)
-BISON_TARGET(DocSelParser grammar/parser.yy ${CMAKE_CURRENT_BINARY_DIR}/parser.cxx)
-FLEX_TARGET(DocSelLexer grammar/lexer.ll ${CMAKE_CURRENT_BINARY_DIR}/lexer.cxx)
+BISON_TARGET(DocSelParser grammar/parser.yy
+ ${CMAKE_CURRENT_BINARY_DIR}/parser.cxx
+ DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/parser.hxx)
+FLEX_TARGET(DocSelLexer grammar/lexer.ll
+ ${CMAKE_CURRENT_BINARY_DIR}/lexer.cxx
+ DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/lexer.hxx)
ADD_FLEX_BISON_DEPENDENCY(DocSelLexer DocSelParser)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
diff --git a/document/src/vespa/document/select/context.cpp b/document/src/vespa/document/select/context.cpp
index 6d9e0df157b..3a728db33f8 100644
--- a/document/src/vespa/document/select/context.cpp
+++ b/document/src/vespa/document/select/context.cpp
@@ -38,10 +38,14 @@ Context::~Context() { }
std::unique_ptr<Value>
Context::getValue(const vespalib::string & value) const {
- VariableMap::const_iterator iter = _variables->find(value);
-
- if (iter != _variables->end()) {
- return std::make_unique<FloatValue>(iter->second);
+ if (_variables) {
+ VariableMap::const_iterator iter = _variables->find(value);
+
+ if (iter != _variables->end()) {
+ return std::make_unique<FloatValue>(iter->second);
+ } else {
+ return std::make_unique<FloatValue>(0.0);
+ }
} else {
return std::make_unique<FloatValue>(0.0);
}
diff --git a/document/src/vespa/document/select/grammar/lexer.ll b/document/src/vespa/document/select/grammar/lexer.ll
index 8cd5638c122..6483b5e8534 100644
--- a/document/src/vespa/document/select/grammar/lexer.ll
+++ b/document/src/vespa/document/select/grammar/lexer.ll
@@ -1,9 +1,5 @@
/* Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. */
- /* We use the .*xx-suffix to denote a build-time generated file */
-%option outfile="lexer.cxx"
-%option header-file="lexer.hxx"
-
%option c++
/* Uncomment to enable debug tracing of parsing */
/* %option debug */
diff --git a/document/src/vespa/document/select/grammar/parser.yy b/document/src/vespa/document/select/grammar/parser.yy
index baf987355c9..f96bd50378f 100644
--- a/document/src/vespa/document/select/grammar/parser.yy
+++ b/document/src/vespa/document/select/grammar/parser.yy
@@ -1,8 +1,5 @@
/* Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. */
-%output "parser.cxx"
-%defines "parser.hxx"
-
/* Skeleton implementation included as part of the generated source. Note: _not_ covered by the GPL. */
%skeleton "lalr1.cc"
diff --git a/document/src/vespa/document/test/make_bucket_space.cpp b/document/src/vespa/document/test/make_bucket_space.cpp
index be8292fcf71..dae0399e75d 100644
--- a/document/src/vespa/document/test/make_bucket_space.cpp
+++ b/document/src/vespa/document/test/make_bucket_space.cpp
@@ -11,12 +11,12 @@ BucketSpace makeBucketSpace()
BucketSpace makeBucketSpace(const vespalib::string &docTypeName)
{
- // Used by persistence conformance test to map fron document type name
+ // Used by persistence conformance test to map from document type name
// to bucket space. See document::TestDocRepo for known document types.
if (docTypeName == "no") {
- return BucketSpace(2);
+ return BucketSpace(3);
} else if (docTypeName == "testdoctype2") {
- return BucketSpace(1);
+ return BucketSpace(2);
} else {
return makeBucketSpace();
}
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java b/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java
index 66af8061f7c..cbe322aef71 100755
--- a/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java
@@ -1,17 +1,17 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.documentapi;
-import java.time.temporal.TemporalAmount;
+import java.time.Duration;
import java.util.Optional;
/**
* Parameters for creating a synchronous session
*
* @author bjorncs
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public class SyncParameters extends Parameters {
- private final TemporalAmount defaultTimeout;
+ private final Duration defaultTimeout;
/**
* @deprecated Use {@link Builder} instead.
@@ -22,21 +22,21 @@ public class SyncParameters extends Parameters {
this(null);
}
- private SyncParameters(TemporalAmount defaultTimeout) {
+ private SyncParameters(Duration defaultTimeout) {
this.defaultTimeout = defaultTimeout;
}
- public Optional<TemporalAmount> defaultTimeout() {
+ public Optional<Duration> defaultTimeout() {
return Optional.ofNullable(defaultTimeout);
}
public static class Builder {
- private TemporalAmount defaultTimeout;
+ private Duration defaultTimeout;
/**
* Set default timeout for all messagebus operations.
*/
- public void setDefaultTimeout(TemporalAmount defaultTimeout) {
+ public void setDefaultTimeout(Duration defaultTimeout) {
this.defaultTimeout = defaultTimeout;
}
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java
index ee9b1760012..ca55933e302 100755
--- a/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java
@@ -8,13 +8,13 @@ import com.yahoo.document.DocumentRemove;
import com.yahoo.document.DocumentUpdate;
import com.yahoo.documentapi.messagebus.protocol.DocumentProtocol;
-import java.time.temporal.TemporalAmount;
+import java.time.Duration;
/**
* <p>A session for synchronous access to a document repository. This class
* provides simple document access where throughput is not a concern.</p>
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bjorncs
*/
public interface SyncSession extends Session {
@@ -71,7 +71,7 @@ public interface SyncSession extends Session {
* @throws DocumentAccessException on any messagebus error, including timeout ({@link com.yahoo.messagebus.ErrorCode#TIMEOUT}).
*/
// TODO Vespa 7: Remove default implementation. Consider removing get() overloads without timeout.
- default Document get(DocumentId id, TemporalAmount timeout) {
+ default Document get(DocumentId id, Duration timeout) {
return get(id);
}
@@ -88,8 +88,7 @@ public interface SyncSession extends Session {
* @throws DocumentAccessException on any messagebus error, including timeout ({@link com.yahoo.messagebus.ErrorCode#TIMEOUT}).
*/
// TODO Vespa 7: Remove default implementation. Consider removing get() overloads without timeout.
- default Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority priority,
- TemporalAmount timeout) {
+ default Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority priority, Duration timeout) {
return get(id, fieldSet, priority);
}
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java
index f2b1816a410..e02b6029dcf 100755
--- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java
@@ -25,19 +25,18 @@ import com.yahoo.messagebus.MessageBus;
import com.yahoo.messagebus.Reply;
import com.yahoo.messagebus.ReplyHandler;
-import java.time.temporal.ChronoUnit;
-import java.time.temporal.TemporalAmount;
+import java.time.Duration;
/**
* An implementation of the SyncSession interface running over message bus.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bjorncs
*/
public class MessageBusSyncSession implements MessageBusSession, SyncSession, ReplyHandler {
private final MessageBusAsyncSession session;
- private final TemporalAmount defaultTimeout;
+ private final Duration defaultTimeout;
/**
* Creates a new sync session running on message bus logic.
@@ -87,9 +86,9 @@ public class MessageBusSyncSession implements MessageBusSession, SyncSession, Re
return syncSend(msg, defaultTimeout);
}
- private Reply syncSend(Message msg, TemporalAmount timeout) {
+ private Reply syncSend(Message msg, Duration timeout) {
if (timeout != null) {
- msg.setTimeRemaining(timeout.get(ChronoUnit.MILLIS));
+ msg.setTimeRemaining(timeout.toMillis());
}
try {
RequestMonitor monitor = new RequestMonitor();
@@ -135,13 +134,12 @@ public class MessageBusSyncSession implements MessageBusSession, SyncSession, Re
}
@Override
- public Document get(DocumentId id, TemporalAmount timeout) {
+ public Document get(DocumentId id, Duration timeout) {
return get(id, "[all]", DocumentProtocol.Priority.NORMAL_1, timeout);
}
@Override
- public Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority pri,
- TemporalAmount timeout) {
+ public Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority pri, Duration timeout) {
GetDocumentMessage msg = new GetDocumentMessage(id, fieldSet);
msg.setPriority(pri);
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index ca77997bac7..0b8b98fc617 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -1,8 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/log/log.h>
-LOG_SETUP("dense_dot_product_function_test");
-
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/tensor/dense/dense_dot_product_function.h>
@@ -12,16 +9,13 @@ LOG_SETUP("dense_dot_product_function_test");
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>
+#include <vespa/log/log.h>
+LOG_SETUP("dense_dot_product_function_test");
+
using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::tensor;
-ValueType
-makeType(size_t numCells)
-{
- return ValueType::tensor_type({{"x", numCells}});
-}
-
tensor::Tensor::UP
makeTensor(size_t numCells, double cellBias)
{
diff --git a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
index 6f3cdd5f93f..61efdbe6d22 100644
--- a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
@@ -147,7 +147,7 @@ TEST_F("require that builder can be re-used", Fixture)
}
void
-assertTensorCell(const std::vector<size_t> &expAddress,
+assertTensorCell(const DenseTensor::Address &expAddress,
double expCell,
const DenseTensor::CellsIterator &itr)
{
diff --git a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp
index bd3d1ada017..708c2f761f7 100644
--- a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp
@@ -2,9 +2,11 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/tensor/sparse/sparse_tensor_builder.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h>
#include <vespa/vespalib/test/insertion_operators.h>
using namespace vespalib::tensor;
+using namespace vespalib::tensor::sparse;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
@@ -57,10 +59,8 @@ TEST("require that tensor can be constructed")
const ValueType &type = sparseTensor.type();
const SparseTensor::Cells &cells = sparseTensor.cells();
EXPECT_EQUAL(2u, cells.size());
- assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}),
- type, cells);
- assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}),
- type, cells);
+ assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}), type, cells);
+ assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}), type, cells);
}
TEST("require that tensor can be converted to tensor spec")
@@ -94,4 +94,22 @@ TEST("require that dimensions are extracted")
EXPECT_EQUAL("tensor(a{},b{},c{})", sparseTensor.type().to_spec());
}
+void verifyAddressCombiner(const ValueType & a, const ValueType & b, size_t numDim, size_t numOverlapping) {
+ TensorAddressCombiner combiner(a, b);
+ EXPECT_EQUAL(numDim, combiner.numDimensions());
+ EXPECT_EQUAL(numOverlapping, combiner.numOverlappingDimensions());
+}
+TEST("Test sparse tensor address combiner") {
+ verifyAddressCombiner(ValueType::tensor_type({{"a"}}), ValueType::tensor_type({{"b"}}), 2, 0);
+ verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}}), 2, 1);
+ verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}, {"c"}}), 3, 1);
+
+}
+
+TEST("Test essential object sizes") {
+ EXPECT_EQUAL(16u, sizeof(SparseTensorAddressRef));
+ EXPECT_EQUAL(24u, sizeof(std::pair<SparseTensorAddressRef, double>));
+ EXPECT_EQUAL(32u, sizeof(vespalib::hash_node<std::pair<SparseTensorAddressRef, double>>));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index 52a0fbabd22..05c974bd3ff 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -7,10 +7,8 @@
#include <vespa/vespalib/util/approx.h>
#include <vespa/vespalib/util/stash.h>
-namespace vespalib {
-namespace eval {
+namespace vespalib::eval::operation {
-namespace operation {
struct Neg { static double f(double a); };
struct Not { static double f(double a); };
struct Add { static double f(double a, double b); };
@@ -52,7 +50,5 @@ struct IsNan { static double f(double a); };
struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
-} // namespace vespalib::eval::operation
-} // namespace vespalib::eval
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index 0e58d292334..1836f2088f3 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -57,14 +57,14 @@ Address select(const Address &a, const Address &b, const IndexList &selector) {
return result;
}
-size_t get_dimension_size(const ValueType &type, size_t dim_idx) {
+size_t get_dimension_size(const ValueType &type, ValueType::Dimension::size_type dim_idx) {
if (dim_idx == ValueType::Dimension::npos) {
return 1;
}
return type.dimensions()[dim_idx].size;
}
-size_t get_dimension_index(const Address &addr, size_t dim_idx) {
+size_t get_dimension_index(const Address &addr, ValueType::Dimension::size_type dim_idx) {
if (dim_idx == ValueType::Dimension::npos) {
return 0;
}
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 03e6d2bbcdf..1c4973a78ca 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -101,9 +101,9 @@ struct Renamer {
} // namespace vespalib::tensor::<unnamed>
-constexpr size_t ValueType::Dimension::npos;
+constexpr ValueType::Dimension::size_type ValueType::Dimension::npos;
-ValueType::~ValueType() { }
+ValueType::~ValueType() = default;
bool
ValueType::is_sparse() const
{
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 2988cc5204e..a4762acd4c0 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -6,8 +6,7 @@
#include <vector>
#include <memory>
-namespace vespalib {
-namespace eval {
+namespace vespalib::eval {
/**
* The type of a Value. This is used for type-resolution during
@@ -19,12 +18,13 @@ class ValueType
public:
enum class Type { ANY, ERROR, DOUBLE, TENSOR };
struct Dimension {
- static constexpr size_t npos = -1;
+ using size_type = uint32_t;
+ static constexpr size_type npos = -1;
vespalib::string name;
- size_t size;
+ size_type size;
Dimension(const vespalib::string &name_in)
: name(name_in), size(npos) {}
- Dimension(const vespalib::string &name_in, size_t size_in)
+ Dimension(const vespalib::string &name_in, size_type size_in)
: name(name_in), size(size_in) {}
bool operator==(const Dimension &rhs) const {
return ((name == rhs.name) && (size == rhs.size));
@@ -91,5 +91,4 @@ public:
std::ostream &operator<<(std::ostream &os, const ValueType &type);
-} // namespace vespalib::eval
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/cell_function.h b/eval/src/vespa/eval/tensor/cell_function.h
index d758cf60634..a268c9a34b1 100644
--- a/eval/src/vespa/eval/tensor/cell_function.h
+++ b/eval/src/vespa/eval/tensor/cell_function.h
@@ -4,8 +4,7 @@
#include <functional>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Interface for a function to be applied on cells in a tensor.
@@ -17,5 +16,4 @@ struct CellFunction
virtual double apply(double value) const = 0;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
index 992f2eae750..fdd0cd6638f 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
@@ -6,8 +6,7 @@
#include <vespa/eval/eval/value.h>
#include <vespa/eval/tensor/tensor.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using CellsRef = DenseTensorView::CellsRef;
@@ -39,5 +38,5 @@ DenseDotProductFunction::eval(ConstArrayRef<eval::Value::CREF> params, Stash &st
return stash.create<eval::DoubleValue>(result);
}
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
index 8ad57d69524..288f2afd084 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
@@ -5,8 +5,7 @@
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Tensor function for a dot product between two 1-dimensional dense tensors.
@@ -27,5 +26,5 @@ public:
const eval::Value &eval(ConstArrayRef<eval::Value::CREF> params, Stash &stash) const override;
};
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
index 5d7e0c83267..9693e89bb75 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
@@ -4,12 +4,10 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/exceptions.h>
#include <vespa/eval/eval/operation.h>
-#include <sstream>
using vespalib::eval::TensorSpec;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
namespace {
@@ -84,5 +82,5 @@ DenseTensor::operator==(const DenseTensor &rhs) const
(_cells == rhs._cells);
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
index 1b97438272e..c45d3c7ccb6 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
@@ -8,8 +8,7 @@
#include "dense_tensor_cells_iterator.h"
#include "dense_tensor_view.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A dense tensor where all dimensions are indexed.
@@ -29,16 +28,13 @@ private:
public:
DenseTensor();
~DenseTensor() {}
- DenseTensor(const eval::ValueType &type_in,
- const Cells &cells_in);
- DenseTensor(const eval::ValueType &type_in,
- Cells &&cells_in);
- DenseTensor(eval::ValueType &&type_in,
- Cells &&cells_in);
+ DenseTensor(const eval::ValueType &type_in, const Cells &cells_in);
+ DenseTensor(const eval::ValueType &type_in, Cells &&cells_in);
+ DenseTensor(eval::ValueType &&type_in, Cells &&cells_in);
bool operator==(const DenseTensor &rhs) const;
const Cells &cells() const { return _cells; }
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index 3e9f4f619f0..ef2a56d4582 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -4,33 +4,7 @@
#include <vespa/vespalib/util/exceptions.h>
#include <cassert>
-namespace vespalib {
-namespace tensor {
-
-using Address = DenseTensorAddressCombiner::Address;
-
-namespace {
-
-class AddressReader
-{
-private:
- const Address &_address;
- size_t _idx;
-
-public:
- AddressReader(const Address &address)
- : _address(address),
- _idx(0)
- {}
- size_t nextLabel() {
- return _address[_idx++];
- }
- bool valid() {
- return _idx < _address.size();
- }
-};
-
-}
+namespace vespalib::tensor {
DenseTensorAddressCombiner::~DenseTensorAddressCombiner() { }
@@ -57,35 +31,7 @@ DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lh
_ops.push_back(AddressOp::RHS);
++rhsItr;
}
-}
-
-bool
-DenseTensorAddressCombiner::combine(const CellsIterator &lhsItr,
- const CellsIterator &rhsItr)
-{
- _combinedAddress.clear();
- AddressReader lhsReader(lhsItr.address());
- AddressReader rhsReader(rhsItr.address());
- for (const auto &op : _ops) {
- switch (op) {
- case AddressOp::LHS:
- _combinedAddress.emplace_back(lhsReader.nextLabel());
- break;
- case AddressOp::RHS:
- _combinedAddress.emplace_back(rhsReader.nextLabel());
- break;
- case AddressOp::BOTH:
- size_t lhsLabel = lhsReader.nextLabel();
- size_t rhsLabel = rhsReader.nextLabel();
- if (lhsLabel != rhsLabel) {
- return false;
- }
- _combinedAddress.emplace_back(lhsLabel);
- }
- }
- assert(!lhsReader.valid());
- assert(!rhsReader.valid());
- return true;
+ _combinedAddress.resize(_ops.size());
}
eval::ValueType
@@ -120,5 +66,4 @@ DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs,
eval::ValueType::tensor_type(std::move(result)));
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index 30bfd740fdd..37fad083dc1 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -7,8 +7,7 @@
#include <vespa/eval/tensor/types.h>
#include <vespa/eval/eval/value_type.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
@@ -19,32 +18,57 @@ namespace tensor {
class DenseTensorAddressCombiner
{
public:
- using Address = std::vector<size_t>;
+ using Address = DenseTensorCellsIterator::Address;
private:
- enum class AddressOp {
- LHS,
- RHS,
- BOTH
- };
+ enum class AddressOp { LHS, RHS, BOTH };
using CellsIterator = DenseTensorCellsIterator;
std::vector<AddressOp> _ops;
Address _combinedAddress;
+ class AddressReader
+ {
+ private:
+ const Address &_address;
+ uint32_t _idx;
+ public:
+ AddressReader(const Address &address) : _address(address), _idx(0) {}
+ Address::value_type nextLabel() { return _address[_idx++]; }
+ };
public:
- DenseTensorAddressCombiner(const eval::ValueType &lhs,
- const eval::ValueType &rhs);
+ DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs);
~DenseTensorAddressCombiner();
- bool combine(const CellsIterator &lhsItr,
- const CellsIterator &rhsItr);
const Address &address() const { return _combinedAddress; }
- static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs);
+ bool combine(const CellsIterator &lhsItr, const CellsIterator &rhsItr) {
+ uint32_t index(0);
+ AddressReader lhsReader(lhsItr.address());
+ AddressReader rhsReader(rhsItr.address());
+ for (const auto &op : _ops) {
+ switch (op) {
+ case AddressOp::LHS:
+ _combinedAddress[index] = lhsReader.nextLabel();
+ break;
+ case AddressOp::RHS:
+ _combinedAddress[index] = rhsReader.nextLabel();
+ break;
+ case AddressOp::BOTH:
+ Address::value_type lhsLabel = lhsReader.nextLabel();
+ Address::value_type rhsLabel = rhsReader.nextLabel();
+ if (lhsLabel != rhsLabel) {
+ return false;
+ }
+ _combinedAddress[index] = lhsLabel;
+ }
+ index++;
+ }
+ return true;
+ }
+ static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h
index 36432c420f5..49e075f6999 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h
@@ -2,13 +2,12 @@
#pragma once
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
+ class Tensor;
+ class DenseTensor;
+}
-class Tensor;
-class DenseTensor;
-
-namespace dense {
+namespace vespalib::tensor::dense {
/**
* Creates a new tensor using all combinations of input tensor cells with matching
@@ -22,7 +21,4 @@ template <typename Function>
std::unique_ptr<Tensor>
apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func);
-} // namespace vespalib::tensor::dense
-} // namespace vespalib::tensor
-} // namespace vespalib
-
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index 65fee767690..dc47d02d47c 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -6,9 +6,7 @@
#include "dense_tensor_address_combiner.h"
#include "direct_dense_tensor_builder.h"
-namespace vespalib {
-namespace tensor {
-namespace dense {
+namespace vespalib::tensor::dense {
template <typename Function>
std::unique_ptr<Tensor>
@@ -42,6 +40,4 @@ apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func)
return Tensor::UP();
}
-} // namespace vespalib::tensor::dense
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
index 0b66dd51206..5d52e5f6e0e 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
@@ -83,7 +83,7 @@ DenseTensorBuilder::calculateCellAddress()
const auto &dim = _dimensions[i];
if (label == UNDEFINED_LABEL) {
throw IllegalArgumentException(make_string("Label for dimension '%s' is undefined. "
- "Expected a value in the range [0, %zu>",
+ "Expected a value in the range [0, %u>",
dim.name.c_str(), dim.size));
}
result += (label * multiplier);
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
index 765ed57393a..3969a9335b8 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
@@ -6,8 +6,7 @@
#include <vespa/vespalib/stllike/hash_map.h>
#include <vespa/eval/tensor/tensor_builder.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A builder of for dense tensors.
@@ -38,5 +37,5 @@ public:
Tensor::UP build();
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp
index 59b4646a22b..d20c5124330 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp
@@ -2,23 +2,14 @@
#include "dense_tensor_cells_iterator.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
-void
-DenseTensorCellsIterator::next()
-{
- ++_cellIdx;
- if (valid()) {
- for (int64_t i = (_address.size() - 1); i >= 0; --i) {
- _address[i] = (_address[i] + 1) % _type.dimensions()[i].size;
- if (_address[i] != 0) {
- // Outer dimension labels can only be increased when this label wraps around.
- break;
- }
- }
- }
-}
+DenseTensorCellsIterator::DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells)
+ : _type(type_in),
+ _cells(cells),
+ _cellIdx(0),
+ _address(type_in.dimensions().size(), 0)
+{}
+DenseTensorCellsIterator::~DenseTensorCellsIterator() = default;
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
index f77517bfdc5..fcffecef764 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
@@ -8,34 +8,41 @@
#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/util/arrayref.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Utility class to iterate over cells in a dense tensor.
*/
class DenseTensorCellsIterator
{
+public:
+ using size_type = eval::ValueType::Dimension::size_type;
+ using Address = std::vector<size_type>;
private:
using CellsRef = vespalib::ConstArrayRef<double>;
const eval::ValueType &_type;
CellsRef _cells;
- size_t _cellIdx;
- std::vector<size_t> _address;
-
+ size_t _cellIdx;
+ Address _address;
public:
- DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells)
- : _type(type_in),
- _cells(cells),
- _cellIdx(0),
- _address(type_in.dimensions().size(), 0)
- {}
+ DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells);
+ ~DenseTensorCellsIterator();
+ void next() {
+ ++_cellIdx;
+ for (int64_t i = (_address.size() - 1); i >= 0; --i) {
+ _address[i]++;
+ if (__builtin_expect((_address[i] != _type.dimensions()[i].size), true)) {
+ // Outer dimension labels can only be increased when this label wraps around.
+ break;
+ } else {
+ _address[i] = 0;
+ }
+ }
+ }
bool valid() const { return _cellIdx < _cells.size(); }
- void next();
double cell() const { return _cells[_cellIdx]; }
- const std::vector<size_t> &address() const { return _address; }
+ const Address &address() const { return _address; }
const eval::ValueType &fast_type() const { return _type; }
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp
index 1268a46b8e5..22e2a3fb78c 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp
@@ -11,8 +11,7 @@ using namespace vespalib::eval;
using namespace vespalib::eval::tensor_function;
using namespace vespalib::eval::operation;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
namespace {
@@ -89,5 +88,5 @@ DenseTensorFunctionCompiler::compile(const eval::tensor_function::Node &expr, St
return InnerProductFunctionCompiler::compile(expr, stash);
}
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h
index d5ba4e4f7a7..61c3af079e3 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h
@@ -4,11 +4,9 @@
#include <vespa/eval/eval/tensor_function.h>
-namespace vespalib {
+namespace vespalib { class Stash; }
-class Stash;
-
-namespace tensor {
+namespace vespalib::tensor {
/**
* Class that recognizes calculations over dense tensors (in tensor function intermediate representation)
@@ -19,5 +17,5 @@ struct DenseTensorFunctionCompiler
static const eval::TensorFunction &compile(const eval::tensor_function::Node &expr, Stash &stash);
};
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h
index d8f47d2234c..fb054318985 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h
@@ -4,9 +4,7 @@
#include "dense_tensor.h"
-namespace vespalib {
-namespace tensor {
-namespace dense {
+namespace vespalib::tensor::dense {
/**
* Returns a tensor with the given dimension(s) removed and the cell values in that dimension(s)
@@ -16,6 +14,5 @@ template<typename Function>
std::unique_ptr<Tensor>
reduce(const DenseTensorView &tensor, const std::vector<vespalib::string> &dimensions, Function &&func);
-} // namespace dense
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 30c9f17348e..74c8981168d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -13,8 +13,7 @@
using vespalib::eval::TensorSpec;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
namespace {
@@ -228,7 +227,7 @@ DenseTensorView::accept(TensorVisitor &visitor) const
addressBuilder.clear();
auto rawIndex = iterator.address().begin();
for (const auto &dimension : _typeRef.dimensions()) {
- label = vespalib::make_string("%zu", *rawIndex);
+ label = vespalib::make_string("%u", *rawIndex);
addressBuilder.add(dimension.name, label);
++rawIndex;
}
@@ -264,5 +263,4 @@ DenseTensorView::reduce(join_fun_t op,
op);
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 5a59594667d..fd95c8555f4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -7,8 +7,7 @@
#include <vespa/eval/eval/value_type.h>
#include "dense_tensor_cells_iterator.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
class DenseTensor;
@@ -22,6 +21,7 @@ public:
using Cells = std::vector<double>;
using CellsRef = ConstArrayRef<double>;
using CellsIterator = DenseTensorCellsIterator;
+ using Address = std::vector<eval::ValueType::Dimension::size_type>;
private:
const eval::ValueType &_typeRef;
@@ -61,5 +61,5 @@ public:
virtual void accept(TensorVisitor &visitor) const override;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
index 45de00dc7fe..1ab78b8ee30 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
@@ -8,8 +8,7 @@
#include <vespa/vespalib/util/exceptions.h>
#include <assert.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
DenseXWProductFunction::DenseXWProductFunction(const eval::ValueType &resultType,
size_t vectorId,
@@ -87,5 +86,5 @@ DenseXWProductFunction::eval(ConstArrayRef<eval::Value::CREF> params, Stash &sta
return stash.create<DenseTensorView>(_resultType, outputCells);
}
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
index db006100e5a..151f1f13800 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
@@ -6,8 +6,7 @@
#include "dense_tensor_view.h"
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using XWInput = DenseTensorView::CellsRef;
using XWOutput = ArrayRef<double>;
@@ -49,5 +48,5 @@ public:
const eval::Value &eval(ConstArrayRef<eval::Value::CREF> params, Stash &stash) const override;
};
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp
index f73d123d4bd..27d72e18f96 100644
--- a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp
+++ b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp
@@ -3,8 +3,7 @@
#include "direct_dense_tensor_builder.h"
#include <cassert>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using Address = DirectDenseTensorBuilder::Address;
using eval::ValueType;
@@ -35,7 +34,7 @@ calculateCellAddress(const Address &address, const ValueType &type)
}
-DirectDenseTensorBuilder::~DirectDenseTensorBuilder() { }
+DirectDenseTensorBuilder::~DirectDenseTensorBuilder() = default;
DirectDenseTensorBuilder::DirectDenseTensorBuilder(const ValueType &type_in)
: _type(type_in),
@@ -57,5 +56,5 @@ DirectDenseTensorBuilder::build()
return std::make_unique<DenseTensor>(std::move(_type), std::move(_cells));
}
-} // namespace tensor
-} // namesapce vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h
index 5e0368e8e69..865decd9fb8 100644
--- a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h
@@ -4,8 +4,7 @@
#include "dense_tensor.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Class for building a dense tensor by inserting cell values directly into underlying array of cells.
@@ -14,7 +13,7 @@ class DirectDenseTensorBuilder
{
public:
using Cells = DenseTensor::Cells;
- using Address = std::vector<size_t>;
+ using Address = DenseTensor::Address;
private:
eval::ValueType _type;
@@ -27,5 +26,5 @@ public:
Tensor::UP build();
};
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp
index 71b7824ee5d..e3b4c8dee42 100644
--- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp
@@ -4,8 +4,7 @@
using vespalib::eval::ValueType;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
MutableDenseTensorView::MutableValueType::MutableValueType(ValueType type_in)
: _type(type_in)
@@ -19,7 +18,7 @@ MutableDenseTensorView::MutableValueType::MutableValueType(ValueType type_in)
}
}
-MutableDenseTensorView::MutableValueType::~MutableValueType() {}
+MutableDenseTensorView::MutableValueType::~MutableValueType() = default;
MutableDenseTensorView::MutableDenseTensorView(ValueType type_in)
: DenseTensorView(_concreteType.fast_type(), CellsRef()),
@@ -33,5 +32,5 @@ MutableDenseTensorView::MutableDenseTensorView(ValueType type_in, CellsRef cells
{
}
-} // namespace tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
index 7eee3a9483c..b68a1594905 100644
--- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
@@ -5,8 +5,7 @@
#include "dense_tensor_view.h"
#include <cassert>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A mutable view to a dense tensor where all dimensions are indexed.
@@ -18,7 +17,7 @@ private:
{
private:
eval::ValueType _type;
- std::vector<size_t *> _unboundDimSizes;
+ std::vector<eval::ValueType::Dimension::size_type *> _unboundDimSizes;
public:
MutableValueType(eval::ValueType type_in);
@@ -55,5 +54,5 @@ public:
}
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/direct_tensor_builder.h b/eval/src/vespa/eval/tensor/direct_tensor_builder.h
index 667cec7c7a9..1eb171eef6e 100644
--- a/eval/src/vespa/eval/tensor/direct_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/direct_tensor_builder.h
@@ -2,8 +2,7 @@
#pragma once
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Forward declaration of utility class to build tensor of type TensorT,
@@ -11,5 +10,4 @@ namespace tensor {
*/
template <typename TensorT> class DirectTensorBuilder;
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h b/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h
index 33000d4889d..c977131fcd3 100644
--- a/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h
@@ -7,8 +7,7 @@
#include "sparse_tensor_address_builder.h"
#include "sparse_tensor_address_padder.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Utility class to build tensors of type SparseTensor, to be used by
@@ -91,9 +90,7 @@ public:
~DirectTensorBuilder() {}
Tensor::UP build() {
- return std::make_unique<SparseTensor>(std::move(_type),
- std::move(_cells),
- std::move(_stash));
+ return std::make_unique<SparseTensor>(std::move(_type), std::move(_cells), std::move(_stash));
}
template <class Function>
@@ -129,7 +126,7 @@ public:
eval::ValueType &fast_type() { return _type; }
Cells &cells() { return _cells; }
+ void reserve(uint32_t estimatedCells) { _cells.resize(estimatedCells*2); }
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index 4762f1eceb4..1aa05bf4f61 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -12,8 +12,6 @@
#include <vespa/vespalib/stllike/hash_map.hpp>
#include <vespa/vespalib/stllike/hash_map_equal.hpp>
#include <vespa/vespalib/util/array_equal.hpp>
-#include <sstream>
-#include <algorithm>
using vespalib::eval::TensorSpec;
@@ -35,8 +33,7 @@ copyCells(Cells &cells, const Cells &cells_in, Stash &stash)
}
-SparseTensor::SparseTensor(const eval::ValueType &type_in,
- const Cells &cells_in)
+SparseTensor::SparseTensor(const eval::ValueType &type_in, const Cells &cells_in)
: _type(type_in),
_cells(),
_stash(STASH_CHUNK_SIZE)
@@ -45,14 +42,13 @@ SparseTensor::SparseTensor(const eval::ValueType &type_in,
}
-SparseTensor::SparseTensor(eval::ValueType &&type_in,
- Cells &&cells_in, Stash &&stash_in)
+SparseTensor::SparseTensor(eval::ValueType &&type_in, Cells &&cells_in, Stash &&stash_in)
: _type(std::move(type_in)),
_cells(std::move(cells_in)),
_stash(std::move(stash_in))
-{
-}
+{ }
+SparseTensor::~SparseTensor() = default;
bool
SparseTensor::operator==(const SparseTensor &rhs) const
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index c7c38f0a182..2715e606729 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -11,8 +11,7 @@
#include <vespa/vespalib/stllike/string.h>
#include <vespa/vespalib/util/stash.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A tensor implementation using serialized tensor addresses to
@@ -22,7 +21,7 @@ namespace tensor {
class SparseTensor : public Tensor
{
public:
- using Cells = vespalib::hash_map<SparseTensorAddressRef, double>;
+ using Cells = hash_map<SparseTensorAddressRef, double>;
static constexpr size_t STASH_CHUNK_SIZE = 16384u;
@@ -32,28 +31,23 @@ private:
Stash _stash;
public:
- explicit SparseTensor(const eval::ValueType &type_in,
- const Cells &cells_in);
- SparseTensor(eval::ValueType &&type_in,
- Cells &&cells_in, Stash &&stash_in);
+ explicit SparseTensor(const eval::ValueType &type_in, const Cells &cells_in);
+ SparseTensor(eval::ValueType &&type_in, Cells &&cells_in, Stash &&stash_in);
+ ~SparseTensor() override;
const Cells &cells() const { return _cells; }
const eval::ValueType &fast_type() const { return _type; }
bool operator==(const SparseTensor &rhs) const;
eval::ValueType combineDimensionsWith(const SparseTensor &rhs) const;
- virtual const eval::ValueType &type() const override;
- virtual double as_double() const override;
- virtual Tensor::UP apply(const CellFunction &func) const override;
- virtual Tensor::UP join(join_fun_t function,
- const Tensor &arg) const override;
- virtual Tensor::UP reduce(join_fun_t op,
- const std::vector<vespalib::string> &dimensions)
- const override;
- virtual bool equals(const Tensor &arg) const override;
- virtual Tensor::UP clone() const override;
- virtual eval::TensorSpec toSpec() const override;
- virtual void accept(TensorVisitor &visitor) const override;
+ const eval::ValueType &type() const override;
+ double as_double() const override;
+ Tensor::UP apply(const CellFunction &func) const override;
+ Tensor::UP join(join_fun_t function, const Tensor &arg) const override;
+ Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override;
+ bool equals(const Tensor &arg) const override;
+ Tensor::UP clone() const override;
+ eval::TensorSpec toSpec() const override;
+ void accept(TensorVisitor &visitor) const override;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h
index e9a66eb4539..f74ce257b31 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h
@@ -2,12 +2,10 @@
#pragma once
-#include <vespa/vespalib/stllike/string.h>
-#include <vector>
#include "sparse_tensor_address_ref.h"
+#include <vespa/vespalib/stllike/string.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
@@ -20,21 +18,26 @@ namespace tensor {
class SparseTensorAddressBuilder
{
private:
- std::vector<char> _address;
+ vespalib::Array<char> _address;
- void
- append(vespalib::stringref str)
- {
- const char *cstr = str.c_str();
- _address.insert(_address.end(), cstr, cstr + str.size() + 1);
+protected:
+ void append(vespalib::stringref str) {
+ for (size_t i(0); i < str.size() + 1; i++) {
+ _address.push_back_fast(str[i]);
+ }
+ }
+ void ensure_room(size_t additional) {
+ if (_address.capacity() < (_address.size() + additional)) {
+ _address.reserve(_address.size() + additional);
+ }
}
public:
- SparseTensorAddressBuilder()
- : _address()
- {
+ SparseTensorAddressBuilder() : _address() {}
+ void add(vespalib::stringref label) {
+ ensure_room(label.size()+1);
+ append(label);
}
- void add(vespalib::stringref label) { append(label); }
- void addUndefined() { _address.emplace_back('\0'); }
+ void addUndefined() { _address.push_back('\0'); }
void clear() { _address.clear(); }
SparseTensorAddressRef getAddressRef() const {
return SparseTensorAddressRef(&_address[0], _address.size());
@@ -42,6 +45,4 @@ public:
bool empty() const { return _address.empty(); }
};
-
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp
index b386ec82528..e0de63b90d2 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp
@@ -5,12 +5,9 @@
#include <vespa/eval/eval/value_type.h>
#include <cassert>
-namespace vespalib {
-namespace tensor {
-namespace sparse {
+namespace vespalib::tensor::sparse {
-TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs,
- const eval::ValueType &rhs)
+TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs)
{
auto rhsItr = rhs.dimensions().cbegin();
auto rhsItrEnd = rhs.dimensions().cend();
@@ -32,8 +29,17 @@ TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs,
}
}
-TensorAddressCombiner::~TensorAddressCombiner()
-{
+TensorAddressCombiner::~TensorAddressCombiner() = default;
+
+size_t
+TensorAddressCombiner::numOverlappingDimensions() const {
+ size_t count = 0;
+ for (AddressOp op : _ops) {
+ if (op == AddressOp::BOTH) {
+ count++;
+ }
+ }
+ return count;
}
bool
@@ -41,15 +47,16 @@ TensorAddressCombiner::combine(SparseTensorAddressRef lhsRef,
SparseTensorAddressRef rhsRef)
{
clear();
+ ensure_room(lhsRef.size() + rhsRef.size());
SparseTensorAddressDecoder lhs(lhsRef);
SparseTensorAddressDecoder rhs(rhsRef);
for (auto op : _ops) {
switch (op) {
case AddressOp::LHS:
- add(lhs.decodeLabel());
+ append(lhs.decodeLabel());
break;
case AddressOp::RHS:
- add(rhs.decodeLabel());
+ append(rhs.decodeLabel());
break;
case AddressOp::BOTH:
auto lhsLabel(lhs.decodeLabel());
@@ -57,14 +64,10 @@ TensorAddressCombiner::combine(SparseTensorAddressRef lhsRef,
if (lhsLabel != rhsLabel) {
return false;
}
- add(lhsLabel);
+ append(lhsLabel);
}
}
- assert(!lhs.valid());
- assert(!rhs.valid());
return true;
}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h
index 402b4bc598a..1a7f2fd8d3c 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h
@@ -3,12 +3,11 @@
#pragma once
#include "sparse_tensor_address_builder.h"
-#include <vespa/eval/tensor/types.h>
-namespace vespalib {
-namespace eval { class ValueType; }
-namespace tensor {
-namespace sparse {
+#define VESPA_DLL_LOCAL __attribute__ ((visibility("hidden")))
+
+namespace vespalib::eval { class ValueType; }
+namespace vespalib::tensor::sparse {
/**
* Combine two tensor addresses to a new tensor address. Common dimensions
@@ -16,25 +15,17 @@ namespace sparse {
*/
class TensorAddressCombiner : public SparseTensorAddressBuilder
{
- enum class AddressOp
- {
- LHS,
- RHS,
- BOTH
- };
+ enum class AddressOp { LHS, RHS, BOTH };
std::vector<AddressOp> _ops;
-
public:
- TensorAddressCombiner(const eval::ValueType &lhs,
- const eval::ValueType &rhs);
-
+ TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs);
~TensorAddressCombiner();
- bool combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef);
+ VESPA_DLL_LOCAL bool combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef);
+ size_t numOverlappingDimensions() const;
+ size_t numDimensions() const { return _ops.size(); }
};
+}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h
index 3a0502aee5b..2fbd9932009 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h
@@ -5,10 +5,7 @@
#include <vespa/vespalib/stllike/string.h>
#include "sparse_tensor_address_ref.h"
-namespace vespalib {
-
-
-namespace tensor {
+namespace vespalib::tensor {
/**
* A decoder for a serialized tensor address, with only labels present.
@@ -40,5 +37,5 @@ public:
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h
index 506f8b29593..29e10c778ba 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h
@@ -6,8 +6,7 @@
#include "sparse_tensor_address_decoder.h"
#include <cassert>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
@@ -16,11 +15,7 @@ namespace tensor {
*/
class SparseTensorAddressPadder : public SparseTensorAddressBuilder
{
- enum class PadOp
- {
- PAD,
- COPY
- };
+ enum class PadOp { PAD, COPY };
std::vector<PadOp> _padOps;
@@ -67,6 +62,5 @@ public:
}
};
+}
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp
index 7da5bd8d61a..fbd0034bc14 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp
@@ -4,20 +4,16 @@
#include <vespa/eval/eval/value_type.h>
#include <vespa/vespalib/stllike/hash_set.hpp>
-namespace vespalib {
-namespace tensor {
-namespace sparse {
+namespace vespalib::tensor::sparse {
TensorAddressReducer::TensorAddressReducer(const eval::ValueType &type,
- const std::vector<vespalib::string> &
- removeDimensions)
+ const std::vector<vespalib::string> & removeDimensions)
: SparseTensorAddressBuilder(),
_ops()
{
- TensorDimensionsSet removeSet(removeDimensions.cbegin(),
- removeDimensions.cend());
+ TensorDimensionsSet removeSet(removeDimensions.cbegin(), removeDimensions.cend());
_ops.reserve(type.dimensions().size());
- for (auto &dim : type.dimensions()) {
+ for (const auto &dim : type.dimensions()) {
if (removeSet.find(dim.name) != removeSet.end()) {
_ops.push_back(AddressOp::REMOVE);
} else {
@@ -26,10 +22,7 @@ TensorAddressReducer::TensorAddressReducer(const eval::ValueType &type,
}
}
-TensorAddressReducer::~TensorAddressReducer()
-{
+TensorAddressReducer::~TensorAddressReducer() = default;
+
}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h
index c40d34d9a53..a2034d3be49 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h
@@ -7,21 +7,15 @@
#include "sparse_tensor_address_decoder.h"
#include <cassert>
-namespace vespalib {
-namespace eval { class ValueType; }
-namespace tensor {
-namespace sparse {
+namespace vespalib::eval { class ValueType; }
+namespace vespalib::tensor::sparse {
/**
* Reduce sparse tensor address by removing one or more dimensions.
*/
class TensorAddressReducer : public SparseTensorAddressBuilder
{
- enum AddressOp
- {
- REMOVE,
- COPY
- };
+ enum AddressOp { REMOVE, COPY };
using AddressOps = std::vector<AddressOp>;
@@ -50,7 +44,5 @@ public:
}
};
+}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h
index 788bf1b8ddc..321690085be 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h
@@ -2,9 +2,8 @@
#pragma once
-#include <vespa/vespalib/stllike/string.h>
-#include <vector>
#include <vespa/vespalib/util/stash.h>
+#include <cstring>
namespace vespalib {
@@ -19,15 +18,15 @@ namespace tensor {
class SparseTensorAddressRef
{
const void *_start;
- size_t _size;
- size_t _hash;
+ uint32_t _size;
+ uint32_t _hash;
public:
SparseTensorAddressRef()
: _start(nullptr), _size(0u), _hash(0u)
{
}
- SparseTensorAddressRef(const void *start_in, size_t size_in)
+ SparseTensorAddressRef(const void *start_in, uint32_t size_in)
: _start(start_in), _size(size_in),
_hash(calcHash())
{
@@ -43,9 +42,9 @@ public:
_start = res;
}
- size_t hash() const { return _hash; }
+ uint32_t hash() const { return _hash; }
- size_t calcHash() const { return hashValue(_start, _size); }
+ uint32_t calcHash() const { return hashValue(_start, _size); }
bool operator<(const SparseTensorAddressRef &rhs) const {
size_t minSize = std::min(_size, rhs._size);
@@ -65,7 +64,7 @@ public:
}
const void *start() const { return _start; }
- size_t size() const { return _size; }
+ uint32_t size() const { return _size; }
};
} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h
index 65d05bd4ba2..ec6edf2d847 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h
@@ -2,11 +2,12 @@
#pragma once
-namespace vespalib {
-namespace tensor {
-class Tensor;
-class SparseTensor;
-namespace sparse {
+namespace vespalib::tensor {
+ class Tensor;
+ class SparseTensor;
+}
+
+namespace vespalib::tensor::sparse {
/**
* Create new tensor using all combinations of input tensor cells with matching
@@ -17,7 +18,5 @@ template <typename Function>
std::unique_ptr<Tensor>
apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func);
+}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp
index 4528c8ef1df..2027e0afc9d 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp
@@ -7,9 +7,7 @@
#include <vespa/eval/tensor/direct_tensor_builder.h>
#include "direct_sparse_tensor_builder.h"
-namespace vespalib {
-namespace tensor {
-namespace sparse {
+namespace vespalib::tensor::sparse {
template <typename Function>
std::unique_ptr<Tensor>
@@ -17,10 +15,14 @@ apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func)
{
DirectTensorBuilder<SparseTensor> builder(lhs.combineDimensionsWith(rhs));
TensorAddressCombiner addressCombiner(lhs.fast_type(), rhs.fast_type());
+ size_t estimatedCells = (lhs.cells().size() * rhs.cells().size());
+ if (addressCombiner.numOverlappingDimensions() != 0) {
+ estimatedCells = std::min(lhs.cells().size(), rhs.cells().size());
+ }
+ builder.reserve(estimatedCells*2);
for (const auto &lhsCell : lhs.cells()) {
for (const auto &rhsCell : rhs.cells()) {
- bool combineSuccess = addressCombiner.combine(lhsCell.first,
- rhsCell.first);
+ bool combineSuccess = addressCombiner.combine(lhsCell.first, rhsCell.first);
if (combineSuccess) {
builder.insertCell(addressCombiner.getAddressRef(),
func(lhsCell.second, rhsCell.second));
@@ -30,6 +32,4 @@ apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func)
return builder.build();
}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp
index dacf0c27593..9c3b13f6260 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp
@@ -3,8 +3,7 @@
#include "sparse_tensor_builder.h"
#include <cassert>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
SparseTensorBuilder::SparseTensorBuilder()
: TensorBuilder(),
@@ -19,10 +18,7 @@ SparseTensorBuilder::SparseTensorBuilder()
{
}
-SparseTensorBuilder::~SparseTensorBuilder()
-{
-}
-
+SparseTensorBuilder::~SparseTensorBuilder() = default;
void
SparseTensorBuilder::makeType()
@@ -103,6 +99,5 @@ SparseTensorBuilder::build()
return ret;
}
+}
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h
index af1566d46c5..ea5f607ff7e 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h
@@ -10,8 +10,7 @@
#include <vespa/vespalib/stllike/hash_map.h>
#include <vespa/vespalib/util/stash.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A builder of sparse tensors.
@@ -30,17 +29,13 @@ class SparseTensorBuilder : public TensorBuilder
void makeType();
public:
SparseTensorBuilder();
- virtual ~SparseTensorBuilder();
+ ~SparseTensorBuilder() override;
- virtual Dimension
- define_dimension(const vespalib::string &dimension) override;
- virtual TensorBuilder &
- add_label(Dimension dimension,
- const vespalib::string &label) override;
- virtual TensorBuilder &add_cell(double value) override;
-
- virtual Tensor::UP build() override;
+ Dimension define_dimension(const vespalib::string &dimension) override;
+ TensorBuilder & add_label(Dimension dimension, const vespalib::string &label) override;
+ TensorBuilder &add_cell(double value) override;
+ Tensor::UP build() override;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
+
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp
index b4c9d511d09..cd5715e7379 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp
@@ -1,9 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "sparse_tensor_match.h"
+#include <vespa/vespalib/stllike/hash_map.hpp>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
namespace {
@@ -73,9 +73,9 @@ transformAddress(SparseTensorAddressBuilder &builder,
void
-SparseTensorMatch::fastMatch(const TensorImplType &lhs,
- const TensorImplType &rhs)
+SparseTensorMatch::fastMatch(const TensorImplType &lhs, const TensorImplType &rhs)
{
+ _builder.reserve(lhs.cells().size());
for (const auto &lhsCell : lhs.cells()) {
auto rhsItr = rhs.cells().find(lhsCell.first);
if (rhsItr != rhs.cells().end()) {
@@ -85,13 +85,11 @@ SparseTensorMatch::fastMatch(const TensorImplType &lhs,
}
void
-SparseTensorMatch::slowMatch(const TensorImplType &lhs,
- const TensorImplType &rhs)
+SparseTensorMatch::slowMatch(const TensorImplType &lhs, const TensorImplType &rhs)
{
std::vector<AddressOp> ops;
SparseTensorAddressBuilder addressBuilder;
- SparseTensorAddressPadder addressPadder(_builder.fast_type(),
- lhs.fast_type());
+ SparseTensorAddressPadder addressPadder(_builder.fast_type(), lhs.fast_type());
buildTransformOps(ops, lhs.fast_type(), rhs.fast_type());
for (const auto &lhsCell : lhs.cells()) {
if (!transformAddress(addressBuilder, lhsCell.first, ops)) {
@@ -106,8 +104,7 @@ SparseTensorMatch::slowMatch(const TensorImplType &lhs,
}
}
-SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs,
- const TensorImplType &rhs)
+SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs, const TensorImplType &rhs)
: Parent(lhs.combineDimensionsWith(rhs))
{
if ((lhs.fast_type().dimensions().size() == rhs.fast_type().dimensions().size()) &&
@@ -123,6 +120,4 @@ SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs,
}
}
-
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h
index d88386ec508..bb2c82a6d00 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h
@@ -4,8 +4,7 @@
#include <vespa/eval/tensor/tensor_operation.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Returns the match product of two tensors.
@@ -27,5 +26,4 @@ public:
SparseTensorMatch(const TensorImplType &lhs, const TensorImplType &rhs);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp
index 53ab8116255..8a43c6b52bd 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp
@@ -6,9 +6,7 @@
#include <vespa/eval/tensor/direct_tensor_builder.h>
#include "direct_sparse_tensor_builder.h"
-namespace vespalib {
-namespace tensor {
-namespace sparse {
+namespace vespalib::tensor::sparse {
template <typename Function>
std::unique_ptr<Tensor>
@@ -50,6 +48,7 @@ reduce(const SparseTensor &tensor,
return reduceAll(tensor, builder, func);
}
TensorAddressReducer addressReducer(tensor.fast_type(), dimensions);
+ builder.reserve(tensor.cells().size()*2);
for (const auto &cell : tensor.cells()) {
addressReducer.reduce(cell.first);
builder.insertCell(addressReducer.getAddressRef(), cell.second, func);
@@ -57,6 +56,4 @@ reduce(const SparseTensor &tensor,
return builder.build();
}
-} // namespace vespalib::tensor::sparse
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp
index 1e112cbaa6e..866956dd23e 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp
@@ -14,12 +14,10 @@ SparseTensorUnsortedAddressBuilder::SparseTensorUnsortedAddressBuilder()
{
}
-SparseTensorUnsortedAddressBuilder::~SparseTensorUnsortedAddressBuilder() {
-}
+SparseTensorUnsortedAddressBuilder::~SparseTensorUnsortedAddressBuilder() = default;
void
-SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder &
- builder,
+SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder & builder,
const eval::ValueType &type)
{
const char *base = &_elementStrings[0];
@@ -47,3 +45,4 @@ SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder &
}
}
+
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h
index 24519e924d9..681bdabc5eb 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h
@@ -6,9 +6,8 @@
#include <vector>
#include <vespa/eval/tensor/types.h>
-namespace vespalib {
-namespace eval { class ValueType; }
-namespace tensor {
+namespace vespalib::eval { class ValueType; }
+namespace vespalib::tensor {
class SparseTensorAddressBuilder;
@@ -73,11 +72,9 @@ public:
* Sort the stored tensor address and pass it over to a strict
* tensor address builder in sorted order.
*/
- void buildTo(SparseTensorAddressBuilder &builder,
- const eval::ValueType &type);
+ void buildTo(SparseTensorAddressBuilder &builder, const eval::ValueType &type);
void clear() { _elementStrings.clear(); _elements.clear(); }
};
+}
-} // namespace vespalib::tensor
-} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/tensor_address.h b/eval/src/vespa/eval/tensor/tensor_address.h
index 74b2aff5561..c8c60ef6fa6 100644
--- a/eval/src/vespa/eval/tensor/tensor_address.h
+++ b/eval/src/vespa/eval/tensor/tensor_address.h
@@ -8,8 +8,7 @@
#include <map>
#include <vector>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* A sparse immutable address to a tensor cell.
@@ -87,5 +86,4 @@ public:
std::ostream &operator<<(std::ostream &out, const TensorAddress::Elements &elements);
std::ostream &operator<<(std::ostream &out, const TensorAddress &value);
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_address_builder.h b/eval/src/vespa/eval/tensor/tensor_address_builder.h
index 40b784e051a..47ea79fd985 100644
--- a/eval/src/vespa/eval/tensor/tensor_address_builder.h
+++ b/eval/src/vespa/eval/tensor/tensor_address_builder.h
@@ -4,8 +4,7 @@
#include "tensor_address.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
@@ -27,5 +26,4 @@ public:
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_apply.cpp b/eval/src/vespa/eval/tensor/tensor_apply.cpp
index 7c518d0516f..8f0610fed65 100644
--- a/eval/src/vespa/eval/tensor/tensor_apply.cpp
+++ b/eval/src/vespa/eval/tensor/tensor_apply.cpp
@@ -1,9 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensor_apply.h"
+#include <vespa/vespalib/stllike/hash_map.hpp>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
template <class TensorT>
TensorApply<TensorT>::TensorApply(const TensorImplType &tensor,
@@ -17,5 +17,4 @@ TensorApply<TensorT>::TensorApply(const TensorImplType &tensor,
template class TensorApply<SparseTensor>;
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_apply.h b/eval/src/vespa/eval/tensor/tensor_apply.h
index bd675e7ec58..bb5ffdd1885 100644
--- a/eval/src/vespa/eval/tensor/tensor_apply.h
+++ b/eval/src/vespa/eval/tensor/tensor_apply.h
@@ -5,8 +5,7 @@
#include "cell_function.h"
#include "tensor_operation.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Returns a tensor with the given function applied to all cells in the input tensor.
@@ -23,5 +22,4 @@ public:
extern template class TensorApply<SparseTensor>;
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.cpp b/eval/src/vespa/eval/tensor/tensor_mapper.cpp
index 25b369c246d..f1039b08816 100644
--- a/eval/src/vespa/eval/tensor/tensor_mapper.cpp
+++ b/eval/src/vespa/eval/tensor/tensor_mapper.cpp
@@ -8,6 +8,7 @@
#include "wrapped_simple_tensor.h"
#include <vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
+#include <vespa/vespalib/stllike/hash_map.hpp>
#include <limits>
using vespalib::eval::ValueType;
diff --git a/eval/src/vespa/eval/tensor/tensor_operation.h b/eval/src/vespa/eval/tensor/tensor_operation.h
index 6975c21c448..827c16573d5 100644
--- a/eval/src/vespa/eval/tensor/tensor_operation.h
+++ b/eval/src/vespa/eval/tensor/tensor_operation.h
@@ -5,8 +5,7 @@
#include "direct_tensor_builder.h"
#include <vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Base class for an operation over tensors.
@@ -46,5 +45,4 @@ public:
}
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_visitor.h b/eval/src/vespa/eval/tensor/tensor_visitor.h
index 4002aab6e7e..4cd9792afbd 100644
--- a/eval/src/vespa/eval/tensor/tensor_visitor.h
+++ b/eval/src/vespa/eval/tensor/tensor_visitor.h
@@ -6,8 +6,7 @@
#include <vespa/vespalib/stllike/string.h>
#include "types.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* Class for visiting a tensor. First visit must specify dimensions,
@@ -20,5 +19,4 @@ public:
virtual void visit(const TensorAddress &address, double value) = 0;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+} \ No newline at end of file
diff --git a/eval/src/vespa/eval/tensor/types.h b/eval/src/vespa/eval/tensor/types.h
index aa5d8c89707..d969bc0a2fb 100644
--- a/eval/src/vespa/eval/tensor/types.h
+++ b/eval/src/vespa/eval/tensor/types.h
@@ -7,13 +7,11 @@
#include <vector>
#include <map>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using TensorCells = std::map<std::map<vespalib::string, vespalib::string>, double>;
using TensorDimensions = std::vector<vespalib::string>;
using TensorDimensionsSet = vespalib::hash_set<vespalib::string>;
using DenseTensorCells = std::map<std::map<vespalib::string, size_t>, double>;
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java
index 2e58455bc39..b2d1af15867 100644
--- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java
+++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java
@@ -80,13 +80,13 @@ public class FileDistributionRpcServer {
try {
if (pathToFile.isPresent()) {
req.returnValues().add(new StringValue(pathToFile.get().getAbsolutePath()));
- log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' available at " + pathToFile.get());
+ log.log(LogLevel.DEBUG, "File reference '" + fileReference.value() + "' available at " + pathToFile.get());
} else {
log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found, returning error");
req.setError(fileReferenceDoesNotExists, "File reference '" + fileReference.value() + "' not found");
}
} catch (Throwable e) {
- log.log(LogLevel.WARNING, "File reference '" + fileReference.value() + "' got exeption: " + e.getMessage());
+ log.log(LogLevel.WARNING, "File reference '" + fileReference.value() + "' got exception: " + e.getMessage());
req.setError(fileReferenceInternalError, "File reference '" + fileReference.value() + "' removed");
}
req.returnRequest();
@@ -123,5 +123,4 @@ public class FileDistributionRpcServer {
req.returnValues().add(new Int32Value(0));
}
-
}
diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java
index 727786cdc78..5de006cd17c 100644
--- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java
+++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java
@@ -107,7 +107,7 @@ public class FileDownloader {
} else if (!file.canRead()) {
throw new RuntimeException("File with reference '" + fileReference.value() + "'exists, but unable to read it");
} else {
- fileReferenceDownloader.setDownloadStatus(fileReference.value(), 100.0);
+ fileReferenceDownloader.setDownloadStatus(fileReference, 100.0);
return Optional.of(file);
}
}
diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java
index d57ce4ca5de..d9d1b4984eb 100644
--- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java
+++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java
@@ -85,7 +85,7 @@ public class FileReceiver {
try {
inprogressFile = Files.createTempFile(tmpDirectory.toPath(), fileName, ".inprogress").toFile();
} catch (IOException e) {
- String msg = "Failed creating tempfile for inprogress file for(" + fileName + ") in '" + fileReferenceDir.toPath() + "': ";
+ String msg = "Failed creating temp file for inprogress file for(" + fileName + ") in '" + fileReferenceDir.toPath() + "': ";
log.log(LogLevel.ERROR, msg + e.getMessage(), e);
throw new RuntimeException(msg, e);
}
@@ -103,6 +103,7 @@ public class FileReceiver {
Files.write(inprogressFile.toPath(), part, StandardOpenOption.WRITE, StandardOpenOption.APPEND);
} catch (IOException e) {
log.log(LogLevel.ERROR, "Failed writing to file(" + inprogressFile.toPath() + "): " + e.getMessage(), e);
+ inprogressFile.delete();
throw new RuntimeException("Failed writing to file(" + inprogressFile.toPath() + "): ", e);
}
currentFileSize += part.length;
@@ -247,8 +248,11 @@ public class FileReceiver {
log.log(LogLevel.DEBUG, "File moved from " + tempFile.getAbsolutePath()+ " to " + destination.getAbsolutePath());
} catch (FileAlreadyExistsException e) {
// Don't fail if it already exists (we might get the file from several config servers when retrying, servers are down etc.
- // so it might be written already)
+ // so it might be written already). Delete temp file in that case, to avoid filling the disk.
log.log(LogLevel.DEBUG, "File '" + destination.getAbsolutePath() + "' already exists, continuing: " + e.getMessage());
+ try {
+ Files.delete(tempFile.toPath());
+ } catch (IOException ioe) { /* ignore failure */}
} catch (IOException e) {
String message = "Failed moving file '" + tempFile.getAbsolutePath() + "' to '" + destination.getAbsolutePath() + "'";
log.log(LogLevel.ERROR, message, e);
@@ -295,7 +299,7 @@ public class FileReceiver {
try {
session.addPart(partId, part);
} catch (Exception e) {
- log.severe("Got exception + " + e);
+ log.severe("Got exception " + e);
retval = 1;
}
double completeness = (double) session.currentFileSize / (double) session.fileSize;
diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java
index 509231ba7ff..031506487a8 100644
--- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java
+++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java
@@ -65,7 +65,7 @@ public class FileReferenceDownloader {
Thread.sleep(10);
}
}
- catch (InterruptedException e) {}
+ catch (InterruptedException e) { /* ignored */}
}
if ( !downloadStarted) {
@@ -107,7 +107,7 @@ public class FileReferenceDownloader {
if (validateResponse(request)) {
log.log(LogLevel.DEBUG, "Request callback, OK. Req: " + request + "\nSpec: " + connection);
if (request.returnValues().get(0).asInt32() == 0) {
- log.log(LogLevel.INFO, "Found file reference '" + fileReference.value() + "' available at " + connection.getAddress());
+ log.log(LogLevel.DEBUG, "Found file reference '" + fileReference.value() + "' available at " + connection.getAddress());
return true;
} else {
log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found for " + connection.getAddress());
@@ -169,10 +169,6 @@ public class FileReferenceDownloader {
return status;
}
- void setDownloadStatus(String file, double completeness) {
- setDownloadStatus(new FileReference(file), completeness);
- }
-
void setDownloadStatus(FileReference fileReference, double completeness) {
synchronized (downloads) {
downloadStatus.put(fileReference, completeness);
diff --git a/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java b/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java
index 0b3c7ad8d3b..5a3ccfed490 100644
--- a/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java
+++ b/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java
@@ -6,6 +6,7 @@ import org.glassfish.jersey.client.ClientProperties;
import org.glassfish.jersey.client.HttpUrlConnectorProvider;
import org.glassfish.jersey.client.proxy.WebResourceFactory;
+import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.ClientRequestFilter;
@@ -25,23 +26,26 @@ public class JerseyJaxRsClientFactory implements JaxRsClientFactory {
private final int readTimeoutMs;
private final SSLContext sslContext;
private final String userAgent;
+ private final HostnameVerifier hostnameVerifier;
public JerseyJaxRsClientFactory() {
this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS);
}
- public JerseyJaxRsClientFactory(SSLContext sslContext, String userAgent) {
- this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS, sslContext, userAgent);
+ public JerseyJaxRsClientFactory(SSLContext sslContext, HostnameVerifier hostnameVerifier, String userAgent) {
+ this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS, sslContext, hostnameVerifier, userAgent);
}
public JerseyJaxRsClientFactory(final int connectTimeoutMs, final int readTimeoutMs) {
- this(connectTimeoutMs, readTimeoutMs, null, null);
+ this(connectTimeoutMs, readTimeoutMs, null, null, null);
}
- public JerseyJaxRsClientFactory(int connectTimeoutMs, int readTimeoutMs, SSLContext sslContext, String userAgent) {
+ public JerseyJaxRsClientFactory(int connectTimeoutMs, int readTimeoutMs, SSLContext sslContext,
+ HostnameVerifier hostnameVerifier, String userAgent) {
this.connectTimeoutMs = connectTimeoutMs;
this.readTimeoutMs = readTimeoutMs;
this.sslContext = sslContext;
+ this.hostnameVerifier = hostnameVerifier;
this.userAgent = userAgent;
}
@@ -61,7 +65,9 @@ public class JerseyJaxRsClientFactory implements JaxRsClientFactory {
.property(ClientProperties.FOLLOW_REDIRECTS, true);
if (sslContext != null) {
builder.sslContext(sslContext);
- builder.hostnameVerifier((s, sslSession) -> true);
+ }
+ if (hostnameVerifier != null) {
+ builder.hostnameVerifier(hostnameVerifier);
}
if (userAgent != null) {
builder.register((ClientRequestFilter) context ->
diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java b/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java
index a78e4f1af40..1291418083b 100644
--- a/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java
+++ b/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java
@@ -36,12 +36,15 @@ public class StandaloneMain {
void run(String bundleLocation) {
try {
+ // We're not logging at this point since the application is responsible
+ // for setting up logging.
System.out.println("debug\tInitializing application without privileges.");
loader.init(bundleLocation, false);
loader.start();
setupSigTermHandler();
waitForShutdown();
System.out.println("debug\tTrying to shutdown in a controlled manner.");
+ log.log(Level.INFO, "JDisc shutting down");
loader.stop();
System.out.println("debug\tTrying to clean up in a controlled manner.");
loader.destroy();
@@ -50,7 +53,7 @@ public class StandaloneMain {
} catch (Throwable e) {
System.out.print("debug\tUnexpected: ");
e.printStackTrace();
- log.log(Level.SEVERE, "Unexpected: ", e);
+ log.log(Level.SEVERE, "JDisc exiting: Throwable caught: ", e);
System.exit(6);
}
}
diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java
index 5cabe8acd27..31268c823ba 100644
--- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java
+++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java
@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@@ -34,6 +35,7 @@ import static com.yahoo.jdisc.http.server.jetty.Exceptions.throwUnchecked;
/**
* @author Simon Thoresen Hult
+ * @author bjorncs
*/
class HttpRequestDispatch {
@@ -123,10 +125,10 @@ class HttpRequestDispatch {
boolean reportedError = false;
if (error != null) {
- if (error instanceof EofException) {
+ if (error instanceof CompletionException && error.getCause() instanceof EofException) {
log.log(Level.FINE,
- "Network connection was unexpectedly terminated: " + parent.servletRequest.getRequestURI(),
- error);
+ error,
+ () -> "Network connection was unexpectedly terminated: " + parent.servletRequest.getRequestURI());
} else if (!(error instanceof OverloadException || error instanceof BindingNotFoundException)) {
log.log(Level.WARNING, "Request failed: " + parent.servletRequest.getRequestURI(), error);
}
diff --git a/metrics/src/vespa/metrics/metrictimer.h b/metrics/src/vespa/metrics/metrictimer.h
index 096ba3e27af..0282c0f17ad 100644
--- a/metrics/src/vespa/metrics/metrictimer.h
+++ b/metrics/src/vespa/metrics/metrictimer.h
@@ -8,7 +8,7 @@
#pragma once
-#include <vespa/metrics/valuemetric.h>
+#include "valuemetric.h"
#include <chrono>
namespace metrics {
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
index 2d22f4c4ccf..868ebf39f70 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
@@ -179,9 +179,10 @@ public class NodeAdminImpl implements NodeAdmin {
}
}, 0, 55, TimeUnit.SECONDS);
+ int delay = 120; // WARNING: Reducing this will increase the load on config servers.
aclScheduler.scheduleWithFixedDelay(() -> {
if (!isFrozen()) aclMaintainer.run();
- }, 30, 60, TimeUnit.SECONDS);
+ }, 30, delay, TimeUnit.SECONDS);
}
@Override
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java
index a343e431b5a..3777d7e20d1 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java
@@ -36,7 +36,8 @@ import java.util.function.Function;
*/
public class NodeAdminProvider implements Provider<NodeAdminStateUpdater> {
- private static final Duration NODE_AGENT_SCAN_INTERVAL = Duration.ofSeconds(30);
+ // WARNING: reducing the node agent interval will increase the load on the config servers
+ private static final Duration NODE_AGENT_SCAN_INTERVAL = Duration.ofSeconds(60);
private static final Duration NODE_ADMIN_CONVERGE_STATE_INTERVAL = Duration.ofSeconds(30);
private final NodeAdminStateUpdater nodeAdminStateUpdater;
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java
index 3576f37eb9a..94ad94d9a65 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java
@@ -43,7 +43,6 @@ import java.util.Optional;
*/
public class ConfigServerHttpRequestExecutor {
private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(ConfigServerHttpRequestExecutor.class);
- private static final int MAX_LOOPS = 2;
private final ObjectMapper mapper = new ObjectMapper();
private final CloseableHttpClient client;
@@ -108,43 +107,41 @@ public class ConfigServerHttpRequestExecutor {
private <T> T tryAllConfigServers(CreateRequest requestFactory, Class<T> wantedReturnType) {
Exception lastException = null;
- for (int loopRetry = 0; loopRetry < MAX_LOOPS; loopRetry++) {
- for (URI configServer : configServerHosts) {
- final CloseableHttpResponse response;
- try {
- response = client.execute(requestFactory.createRequest(configServer));
- } catch (Exception e) {
- // Failure to communicate with a config server is not abnormal, as they are
- // upgraded at the same time as Docker hosts.
- if (e.getMessage().indexOf("(Connection refused)") > 0) {
- NODE_ADMIN_LOGGER.info("Connection refused to " + configServer + " (upgrading?), will try next");
- } else {
- NODE_ADMIN_LOGGER.warning("Failed to communicate with " + configServer + ", will try next: " + e.getMessage());
- }
- lastException = e;
+ for (URI configServer : configServerHosts) {
+ final CloseableHttpResponse response;
+ try {
+ response = client.execute(requestFactory.createRequest(configServer));
+ } catch (Exception e) {
+ // Failure to communicate with a config server is not abnormal, as they are
+ // upgraded at the same time as Docker hosts.
+ if (e.getMessage().indexOf("(Connection refused)") > 0) {
+ NODE_ADMIN_LOGGER.info("Connection refused to " + configServer + " (upgrading?), will try next");
+ } else {
+ NODE_ADMIN_LOGGER.warning("Failed to communicate with " + configServer + ", will try next: " + e.getMessage());
+ }
+ lastException = e;
+ continue;
+ }
+
+ try {
+ Optional<HttpException> retryableException = HttpException.handleStatusCode(
+ response.getStatusLine().getStatusCode(),
+ "Config server " + configServer);
+ if (retryableException.isPresent()) {
+ lastException = retryableException.get();
continue;
}
try {
- Optional<HttpException> retryableException = HttpException.handleStatusCode(
- response.getStatusLine().getStatusCode(),
- "Config server " + configServer);
- if (retryableException.isPresent()) {
- lastException = retryableException.get();
- continue;
- }
-
- try {
- return mapper.readValue(response.getEntity().getContent(), wantedReturnType);
- } catch (IOException e) {
- throw new RuntimeException("Response didn't contain nodes element, failed parsing?", e);
- }
- } finally {
- try {
- response.close();
- } catch (IOException e) {
- NODE_ADMIN_LOGGER.warning("Ignoring exception from closing response", e);
- }
+ return mapper.readValue(response.getEntity().getContent(), wantedReturnType);
+ } catch (IOException e) {
+ throw new RuntimeException("Response didn't contain nodes element, failed parsing?", e);
+ }
+ } finally {
+ try {
+ response.close();
+ } catch (IOException e) {
+ NODE_ADMIN_LOGGER.warning("Ignoring exception from closing response", e);
}
}
}
diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java
index 67cd2c79034..799f8a72fd9 100644
--- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java
+++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java
@@ -117,8 +117,7 @@ public class ConfigServerHttpRequestExecutorTest {
}
String[] log = mockLog.toString().split(" ");
- assertThat(log, arrayContainingInAnyOrder("GET http://host1:666/path", "GET http://host2:666/path",
- "GET http://host1:666/path", "GET http://host2:666/path"));
+ assertThat(log, arrayContainingInAnyOrder("GET http://host1:666/path", "GET http://host2:666/path"));
}
@Test
@@ -134,7 +133,6 @@ public class ConfigServerHttpRequestExecutorTest {
String[] log = mockLog.toString().split(" ");
assertThat(log, arrayContainingInAnyOrder(
- "GET http://host1:666/path", "GET http://host2:666/path",
"GET http://host1:666/path", "GET http://host2:666/path"));
}
diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java
index 4681010940c..de08bdbe107 100644
--- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java
+++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java
@@ -2,7 +2,6 @@
package com.yahoo.vespa.hosted.node.maintainer;
import com.yahoo.collections.Pair;
-import static com.yahoo.vespa.defaults.Defaults.getDefaults;
import com.yahoo.system.ProcessExecuter;
import java.io.IOException;
@@ -19,6 +18,8 @@ import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import static com.yahoo.vespa.defaults.Defaults.getDefaults;
+
/**
* Takes in a compressed (lz4) or uncompressed core dump and collects relevant metadata.
*
@@ -166,7 +167,7 @@ public class CoreCollector {
Path decompressedPath = Paths.get(coredumpPath.toString().replaceFirst("\\.lz4$", ""));
Pair<Integer, String> result = processExecuter.exec(
- new String[]{LZ4_PATH, "-d", coredumpPath.toString(), decompressedPath.toString()});
+ new String[]{LZ4_PATH, "-f", "-d", coredumpPath.toString(), decompressedPath.toString()});
if (result.getFirst() != 0) {
throw new RuntimeException("Failed to decompress file " + coredumpPath + ": " + result);
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
index 7ef609d6311..62b00f914a3 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
@@ -50,6 +50,10 @@ public class NodeRepositoryProvisioner implements Provisioner {
private final Activator activator;
private final BiConsumer<List<Node>, String> debugRecorder;
+ int getSpareCapacityProd() {
+ return SPARE_CAPACITY_PROD;
+ }
+
@Inject
public NodeRepositoryProvisioner(NodeRepository nodeRepository, NodeFlavors flavors, Zone zone) {
this(nodeRepository, flavors, zone, Clock.systemUTC(), (x, y) -> {});
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java
index 4c6f1b022a4..23a6e3a8b9a 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java
@@ -11,14 +11,14 @@ import java.util.Objects;
* A specification of a set of nodes.
* This reflects that nodes can be requested either by count and flavor or by type,
* and encapsulates the differences in logic between these two cases.
- *
+ *
* @author bratseth
*/
public interface NodeSpec {
/** The node type this requests */
NodeType type();
-
+
/** Returns whether the given flavor is compatible with this spec */
boolean isCompatible(Flavor flavor);
@@ -33,15 +33,15 @@ public interface NodeSpec {
/** Returns whether the given node count is sufficient to fulfill this spec */
boolean fulfilledBy(int count);
-
+
/** Returns the amount the given count is above the minimum amount needed to fulfill this request */
int surplusGiven(int count);
-
+
/** Returns a specification of a fraction of all the nodes of this. It is assumed the argument is a valid divisor. */
NodeSpec fraction(int divisor);
- /**
- * Assigns the flavor requested by this to the given node and returns it,
+ /**
+ * Assigns the flavor requested by this to the given node and returns it,
* if one is requested and it is allowed to change.
* Otherwise, the node is returned unchanged.
*/
@@ -50,17 +50,17 @@ public interface NodeSpec {
static NodeSpec from(int nodeCount, Flavor flavor) {
return new CountNodeSpec(nodeCount, flavor);
}
-
+
static NodeSpec from(NodeType type) {
return new TypeNodeSpec(type);
}
-
+
/** A node spec specifying a node count and a flavor */
class CountNodeSpec implements NodeSpec {
-
+
private final int count;
private final Flavor requestedFlavor;
-
+
public CountNodeSpec(int count, Flavor flavor) {
Objects.requireNonNull(flavor, "A flavor must be specified");
this.count = count;
@@ -79,7 +79,7 @@ public interface NodeSpec {
public NodeType type() { return NodeType.tenant; }
@Override
- public boolean isCompatible(Flavor flavor) {
+ public boolean isCompatible(Flavor flavor) {
if (flavor.satisfies(requestedFlavor)) return true;
return requestedFlavorCanBeAchievedByResizing(flavor);
}
@@ -91,7 +91,7 @@ public interface NodeSpec {
public boolean specifiesNonStockFlavor() { return ! requestedFlavor.isStock(); }
@Override
- public boolean fulfilledBy(int count) { return count >= this.count; }
+ public boolean fulfilledBy(int count) { return count >= this.count; }
@Override
public boolean saturatedBy(int count) { return fulfilledBy(count); } // min=max for count specs
@@ -101,12 +101,13 @@ public interface NodeSpec {
@Override
public NodeSpec fraction(int divisor) { return new CountNodeSpec(count/divisor, requestedFlavor); }
-
+
@Override
public Node assignRequestedFlavor(Node node) {
// Docker nodes can change flavor in place
if (requestedFlavorCanBeAchievedByResizing(node.flavor()))
return node.with(requestedFlavor);
+
return node;
}
@@ -115,16 +116,19 @@ public interface NodeSpec {
/** Docker nodes can be downsized in place */
private boolean requestedFlavorCanBeAchievedByResizing(Flavor flavor) {
- return flavor.isDocker() && requestedFlavor.isDocker() && flavor.isLargerThan(requestedFlavor);
+ // TODO: Enable this when we can do it safely
+ // Then also re-enable ProvisioningTest.application_deployment_with_inplace_downsize()
+ // return flavor.isDocker() && requestedFlavor.isDocker() && flavor.isLargerThan(requestedFlavor);
+ return false;
}
-
+
}
/** A node spec specifying a node type. This will accept all nodes of this type. */
class TypeNodeSpec implements NodeSpec {
-
+
private final NodeType type;
-
+
public TypeNodeSpec(NodeType type) {
this.type = type;
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java
index 75d5862f010..1dce5830540 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java
@@ -97,11 +97,11 @@ public class DynamicDockerProvisioningTest {
* Test relocation of nodes from spare hosts.
* <p>
* Setup 4 docker hosts and allocate one container on each (from two different applications)
- * No headroom defined - only 2 spares.
+ * No headroom defined - only getSpareCapacityProd() spares.
* <p>
- * Check that it relocates containers away from the 2 spares
+ * Check that it relocates containers away from the getSpareCapacityProd() spares
* <p>
- * Initial allocation of app 1 and 2 --> final allocation:
+ * Initial allocation of app 1 and 2 --> final allocation (example using 2 spares):
* <p>
* | | | | | | | | | |
* | | | | | --> | 2a | 2b | | |
@@ -139,7 +139,8 @@ public class DynamicDockerProvisioningTest {
hostsWithChildren.add(node.parentHostname().get());
}
}
- Assert.assertEquals(2, hostsWithChildren.size());
+ Assert.assertEquals(4 - tester.provisioner().getSpareCapacityProd(), hostsWithChildren.size());
+
}
/**
@@ -389,8 +390,14 @@ public class DynamicDockerProvisioningTest {
// Verify that there is still capacity (available spare)
// Fail one node and redeploy, Verify that one less node is empty.
- // Setup test
+
ProvisioningTester tester = new ProvisioningTester(new Zone(Environment.prod, RegionName.from("us-east")), flavorsConfig());
+ // Only run test if there _is_ spare capacity
+ if (tester.provisioner().getSpareCapacityProd() == 0) {
+ return;
+ }
+
+ // Setup test
enableDynamicAllocation(tester);
ApplicationId application1 = tester.makeApplicationId();
tester.makeReadyNodes(5, "host-small", NodeType.host, 32);
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java
index afdce0d25cc..a7ea77618bb 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java
@@ -181,7 +181,7 @@ public class ProvisioningTest {
SystemState state5 = prepare(application1, 2, 2, 3, 3, "default", tester);
tester.activate(application1, state5.allHosts);
assertEquals("Superfluous container nodes are also deactivated",
- 4-2 + 5-2 + 1, tester.getNodes(application1, Node.State.inactive).size()); //
+ 4-2 + 5-2 + 1, tester.getNodes(application1, Node.State.inactive).size()); //
assertEquals("Superfluous content nodes are retired",
5-3 + 6-3 - 1, tester.getNodes(application1, Node.State.active).retired().size());
@@ -231,6 +231,8 @@ public class ProvisioningTest {
0, tester.getNodes(application1, Node.State.active).retired().flavor("large").size());
}
+ // TODO: Enable when this feature is re-enabled
+ @Ignore
@Test
public void application_deployment_with_inplace_downsize() {
ProvisioningTester tester = new ProvisioningTester(new Zone(Environment.prod, RegionName.from("us-east")));
@@ -761,7 +763,7 @@ public class ProvisioningTest {
if (nodeCount == 0) return Collections.emptySet(); // this is a shady practice
return new HashSet<>(tester.prepare(application, cluster, nodeCount, groups, flavor));
}
-
+
private static class SystemState {
private Set<HostSpec> allHosts;
@@ -781,7 +783,7 @@ public class ProvisioningTest {
this.content0 = content0;
this.content1 = content1;
}
-
+
/** Returns a host by cluster name and index, or null if there is no host with the given values in this */
public HostSpec hostByMembership(String clusterId, int group, int index) {
for (HostSpec host : allHosts) {
@@ -794,7 +796,7 @@ public class ProvisioningTest {
}
return null;
}
-
+
private boolean groupMatches(Optional<ClusterSpec.Group> clusterGroup, int group) {
if ( ! clusterGroup.isPresent()) return group==0;
return clusterGroup.get().index() == group;
diff --git a/persistence/src/tests/spi/CMakeLists.txt b/persistence/src/tests/spi/CMakeLists.txt
index a130573e028..c51270a420c 100644
--- a/persistence/src/tests/spi/CMakeLists.txt
+++ b/persistence/src/tests/spi/CMakeLists.txt
@@ -2,6 +2,7 @@
vespa_add_library(persistence_testspi
SOURCES
clusterstatetest.cpp
+ fixed_bucket_spaces_test.cpp
DEPENDS
persistence_persistence_conformancetest
persistence
diff --git a/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp b/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp
new file mode 100644
index 00000000000..7e36d80248a
--- /dev/null
+++ b/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp
@@ -0,0 +1,64 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/persistence/spi/fixed_bucket_spaces.h>
+#include <cppunit/extensions/HelperMacros.h>
+
+namespace storage::spi {
+
+struct FixedBucketSpacesTest : CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(FixedBucketSpacesTest);
+ CPPUNIT_TEST(bucket_space_from_name_is_defined_for_default_space);
+ CPPUNIT_TEST(bucket_space_from_name_is_defined_for_global_space);
+ CPPUNIT_TEST(bucket_space_from_name_throws_exception_for_unknown_space);
+ CPPUNIT_TEST(name_from_bucket_space_is_defined_for_default_space);
+ CPPUNIT_TEST(name_from_bucket_space_is_defined_for_global_space);
+ CPPUNIT_TEST(name_from_bucket_space_throws_exception_for_unknown_space);
+ CPPUNIT_TEST_SUITE_END();
+
+ void bucket_space_from_name_is_defined_for_default_space();
+ void bucket_space_from_name_is_defined_for_global_space();
+ void bucket_space_from_name_throws_exception_for_unknown_space();
+ void name_from_bucket_space_is_defined_for_default_space();
+ void name_from_bucket_space_is_defined_for_global_space();
+ void name_from_bucket_space_throws_exception_for_unknown_space();
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(FixedBucketSpacesTest);
+
+using document::BucketSpace;
+
+void FixedBucketSpacesTest::bucket_space_from_name_is_defined_for_default_space() {
+ CPPUNIT_ASSERT_EQUAL(FixedBucketSpaces::default_space(), FixedBucketSpaces::from_string("default"));
+}
+
+void FixedBucketSpacesTest::bucket_space_from_name_is_defined_for_global_space() {
+ CPPUNIT_ASSERT_EQUAL(FixedBucketSpaces::global_space(), FixedBucketSpaces::from_string("global"));
+}
+
+void FixedBucketSpacesTest::bucket_space_from_name_throws_exception_for_unknown_space() {
+ try {
+ FixedBucketSpaces::from_string("banana");
+ CPPUNIT_FAIL("Expected exception on unknown bucket space name");
+ } catch (spi::UnknownBucketSpaceException& e) {
+ }
+}
+
+void FixedBucketSpacesTest::name_from_bucket_space_is_defined_for_default_space() {
+ CPPUNIT_ASSERT_EQUAL(vespalib::stringref("default"),
+ FixedBucketSpaces::to_string(FixedBucketSpaces::default_space()));
+}
+
+void FixedBucketSpacesTest::name_from_bucket_space_is_defined_for_global_space() {
+ CPPUNIT_ASSERT_EQUAL(vespalib::stringref("global"),
+ FixedBucketSpaces::to_string(FixedBucketSpaces::global_space()));
+}
+
+void FixedBucketSpacesTest::name_from_bucket_space_throws_exception_for_unknown_space() {
+ try {
+ FixedBucketSpaces::to_string(BucketSpace(4567));
+ CPPUNIT_FAIL("Expected exception on unknown bucket space value");
+ } catch (spi::UnknownBucketSpaceException& e) {
+ }
+}
+
+} \ No newline at end of file
diff --git a/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp b/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp
index 885d3e9aad7..7f4ea9dcc2e 100644
--- a/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp
+++ b/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp
@@ -8,6 +8,7 @@
#include <vespa/document/update/documentupdate.h>
#include <vespa/document/update/assignvalueupdate.h>
#include <vespa/document/test/make_bucket_space.h>
+#include <vespa/metrics/loadmetric.h>
#include <vespa/vdslib/state/state.h>
#include <vespa/vdslib/state/node.h>
#include <vespa/vdslib/state/nodestate.h>
diff --git a/persistence/src/vespa/persistence/spi/CMakeLists.txt b/persistence/src/vespa/persistence/spi/CMakeLists.txt
index a8b1faadcd3..a2b8fa7a79c 100644
--- a/persistence/src/vespa/persistence/spi/CMakeLists.txt
+++ b/persistence/src/vespa/persistence/spi/CMakeLists.txt
@@ -1,19 +1,20 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(persistence_spi OBJECT
SOURCES
+ abstractpersistenceprovider.cpp
bucket.cpp
bucketinfo.cpp
- exceptions.cpp
- persistenceprovider.cpp
- partitionstate.cpp
- abstractpersistenceprovider.cpp
clusterstate.cpp
context.cpp
+ docentry.cpp
+ exceptions.cpp
+ fixed_bucket_spaces.cpp
metricpersistenceprovider.cpp
+ partitionstate.cpp
+ persistenceprovider.cpp
read_consistency.cpp
- result
+ result.cpp
selection.cpp
test.cpp
- docentry
DEPENDS
)
diff --git a/persistence/src/vespa/persistence/spi/context.h b/persistence/src/vespa/persistence/spi/context.h
index 75d3eac4538..ca4c79e3005 100644
--- a/persistence/src/vespa/persistence/spi/context.h
+++ b/persistence/src/vespa/persistence/spi/context.h
@@ -29,7 +29,6 @@
#pragma once
-#include <vespa/metrics/loadmetric.h>
#include <persistence/spi/types.h>
#include <vespa/persistence/spi/read_consistency.h>
#include <vespa/vespalib/trace/trace.h>
@@ -38,8 +37,7 @@ namespace metrics {
class LoadType;
}
-namespace storage {
-namespace spi {
+namespace storage::spi {
using LoadType = metrics::LoadType;
@@ -93,6 +91,4 @@ public:
{ _trace.trace(level, msg, addTime); }
};
-} // spi
-} // storage
-
+}
diff --git a/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp
new file mode 100644
index 00000000000..6a8ec0f18f7
--- /dev/null
+++ b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp
@@ -0,0 +1,33 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include "fixed_bucket_spaces.h"
+
+namespace storage::spi {
+
+VESPA_IMPLEMENT_EXCEPTION(UnknownBucketSpaceException, vespalib::IllegalArgumentException)
+
+// Some sanity checks to ensure we don't mess up any legacy mappings.
+static_assert(document::BucketSpace::placeHolder() != document::BucketSpace::invalid());
+static_assert(FixedBucketSpaces::default_space() == document::BucketSpace::placeHolder());
+static_assert(FixedBucketSpaces::global_space() != FixedBucketSpaces::default_space());
+
+document::BucketSpace FixedBucketSpaces::from_string(vespalib::stringref name) {
+ if (name == "default") {
+ return default_space();
+ } else if (name == "global") {
+ return global_space();
+ } else {
+ throw UnknownBucketSpaceException("Unknown bucket space name: " + vespalib::string(name), VESPA_STRLOC);
+ }
+}
+
+vespalib::stringref FixedBucketSpaces::to_string(document::BucketSpace space) {
+ if (space == default_space()) {
+ return "default";
+ } else if (space == global_space()) {
+ return "global";
+ } else {
+ throw UnknownBucketSpaceException("Unknown bucket space: " + space.toString(), VESPA_STRLOC);
+ }
+}
+
+}
diff --git a/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h
new file mode 100644
index 00000000000..c2e97407797
--- /dev/null
+++ b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h
@@ -0,0 +1,30 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include <vespa/document/bucket/bucketspace.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/stllike/string.h>
+
+namespace storage::spi {
+
+VESPA_DEFINE_EXCEPTION(UnknownBucketSpaceException, vespalib::IllegalArgumentException);
+
+/**
+ * Minimal repository/factory of bucket spaces hard coded for default and global
+ * distributions.
+ */
+struct FixedBucketSpaces {
+ static constexpr document::BucketSpace default_space() { return document::BucketSpace(1); };
+ static constexpr document::BucketSpace global_space() { return document::BucketSpace(2); }
+
+ // Post-condition: returned space has valid() == true iff name
+ // is either "default" or "global".
+ // Throws UnknownBucketSpaceException if name does not map to a known bucket space.
+ static document::BucketSpace from_string(vespalib::stringref name);
+ // Post-condition: returned string can be losslessly passed to from_string()
+ // iff space is equal to default_space() or global_space().
+ // Throws UnknownBucketSpaceException if space does not map to a known name.
+ static vespalib::stringref to_string(document::BucketSpace space);
+};
+
+}
diff --git a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp
index 76b0a3c4686..58e662a2b1d 100644
--- a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp
+++ b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "metricpersistenceprovider.h"
+#include <vespa/metrics/valuemetric.h>
+#include <vespa/metrics/metrictimer.h>
#include <cassert>
#include <vespa/log/log.h>
diff --git a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h
index e169ad098c7..b804fd21550 100644
--- a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h
+++ b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h
@@ -6,10 +6,10 @@
#pragma once
#include "persistenceprovider.h"
-#include <vespa/metrics/metrics.h>
+#include <vespa/metrics/metricset.h>
+#include <vespa/metrics/valuemetric.h>
-namespace storage {
-namespace spi {
+namespace storage::spi {
class MetricPersistenceProvider : public PersistenceProvider,
public metrics::MetricSet
@@ -61,5 +61,5 @@ private:
void defineResultMetrics(int index, const char* name);
};
-} // spi
-} // storage
+}
+
diff --git a/pom.xml b/pom.xml
index eb1f954ce13..b196034380b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -23,6 +23,941 @@
</developer>
</developers>
+ <distributionManagement>
+ <repository>
+ <id>bintray-vespa-repo</id>
+ <url>https://api.bintray.com/maven/yahoo/maven/vespa;publish=1</url>
+ </repository>
+ </distributionManagement>
+
+ <repositories>
+ <!-- Required for Athenz libraries -->
+ <repository>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ <id>bintray-yahoo-maven</id>
+ <name>bintray</name>
+ <url>https://yahoo.bintray.com/maven</url>
+ </repository>
+ </repositories>
+
+ <scm>
+ <connection>scm:git:git@github.com:vespa-engine/vespa.git</connection>
+ <developerConnection>scm:git:git@github.com:vespa-engine/vespa.git</developerConnection>
+ <url>git@github.com:vespa-engine/vespa.git</url>
+ </scm>
+
+ <build>
+ <finalName>${project.artifactId}</finalName>
+ <extensions>
+ <extension>
+ <groupId>org.apache.maven.wagon</groupId>
+ <artifactId>wagon-ssh-external</artifactId>
+ <version>2.7</version>
+ </extension>
+ <extension>
+ <groupId>org.apache.maven.archetype</groupId>
+ <artifactId>archetype-packaging</artifactId>
+ <version>2.0</version>
+ </extension>
+ </extensions>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr3-maven-plugin</artifactId>
+ <version>${antlr.version}</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <version>1.7</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>maven-bundle-plugin</artifactId>
+ <version>2.4.0</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <version>2.4</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>3.6.1</version>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ <showWarnings>true</showWarnings>
+ <optimize>true</optimize>
+ <showDeprecation>false</showDeprecation>
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Xlint:-try</arg>
+ <arg>-Xlint:-processing</arg>
+ <arg>-Xlint:-varargs</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <version>2.10</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-deploy-plugin</artifactId>
+ <version>2.5</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-install-plugin</artifactId>
+ <version>2.5.2</version>
+ <configuration>
+ <updateReleaseInfo>true</updateReleaseInfo>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>3.0.2</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <configuration>
+ <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam>
+ </configuration>
+ <version>2.10.4</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-plugin-plugin</artifactId>
+ <version>3.5</version>
+ <configuration>
+ <!-- see http://jira.codehaus.org/browse/MNG-5346 -->
+ <skipErrorNoDescriptorsFound>true</skipErrorNoDescriptorsFound>
+ </configuration>
+ <executions>
+ <execution>
+ <id>mojo-descriptor</id>
+ <goals>
+ <goal>descriptor</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-resources-plugin</artifactId>
+ <version>2.7</version>
+ <configuration>
+ <escapeString>\</escapeString>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-site-plugin</artifactId>
+ <version>3.3</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.1.2</version>
+ <configuration>
+ <includePom>true</includePom>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>${surefire.version}</version>
+ <configuration>
+ <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile>
+ <systemPropertyVariables>
+ <java.io.tmpdir>${project.build.directory}</java.io.tmpdir>
+ </systemPropertyVariables>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-report-plugin</artifactId>
+ <version>${surefire.version}</version>
+ <configuration>
+ <alwaysGenerateSurefireReport>false</alwaysGenerateSurefireReport>
+ <showSuccess>false</showSuccess>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <version>1.9.1</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>1.6.0</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>javacc-maven-plugin</artifactId>
+ <version>2.6</version>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>properties-maven-plugin</artifactId>
+ <version>1.0.0</version>
+ </plugin>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>3.2.2</version>
+ <configuration>
+ <args>
+ <arg>-unchecked</arg>
+ <arg>-deprecation</arg>
+ <arg>-feature</arg>
+ <arg>-Xfatal-warnings</arg>
+ </args>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>bundle-plugin</artifactId>
+ <version>${project.version}</version>
+ <configuration>
+ <configGenVersion>${project.version}</configGenVersion>
+ <useCommonAssemblyIds>true</useCommonAssemblyIds>
+ </configuration>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ </build>
+ <profiles>
+ <profile>
+ <id>attach-sources</id>
+ <activation>
+ <property>
+ <name>!skipSources</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>generate-javadoc</id>
+ <activation>
+ <property>
+ <name>!skipJavadoc</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>generate-javadoc</id>
+ <phase>package</phase>
+ <goals>
+ <goal>javadoc</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam>
+ <failOnError>${javadoc.failOnError}</failOnError>
+ <quiet>true</quiet>
+ <show>private</show>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>coverage</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <configuration>
+ <includePluginDependencies>true</includePluginDependencies>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>add-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/main/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ <execution>
+ <id>add-test-source</id>
+ <phase>generate-test-sources</phase>
+ <goals>
+ <goal>add-test-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/test/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>sign-artifacts</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.6</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+ <dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.maven.wagon</groupId>
+ <artifactId>wagon-ssh-external</artifactId>
+ <version>2.7</version>
+ </dependency>
+ <dependency>
+ <groupId>com.github.cverges.expect4j</groupId>
+ <artifactId>expect4j</artifactId>
+ <version>1.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-compress</artifactId>
+ <version>1.11</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-exec</artifactId>
+ <version>1.3</version>
+ </dependency>
+ <dependency>
+ <groupId>io.airlift</groupId>
+ <artifactId>airline</artifactId>
+ <version>0.7</version>
+ </dependency>
+ <dependency>
+ <groupId>aopalliance</groupId>
+ <artifactId>aopalliance</artifactId>
+ <version>1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ <version>5.2</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>annotations</artifactId>
+ <version>1.3.9</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ <version>1.3.9</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <version>18.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava-testlib</artifactId>
+ <version>18.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject</groupId>
+ <artifactId>guice</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject</groupId>
+ <artifactId>guice</artifactId>
+ <version>3.0</version>
+ <classifier>no_aop</classifier>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject.extensions</groupId>
+ <artifactId>guice-assistedinject</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.inject.extensions</groupId>
+ <artifactId>guice-multibindings</artifactId>
+ <version>3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>3.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.googlecode.jmockit</groupId>
+ <artifactId>jmockit</artifactId>
+ <version>1.2</version>
+ </dependency>
+ <dependency>
+ <groupId>com.goldmansachs</groupId>
+ <artifactId>gs-collections</artifactId>
+ <version>6.1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-json-provider</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-jaxb-annotations</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-base</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.jaxrs</groupId>
+ <artifactId>jackson-jaxrs-xml-provider</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.dataformat</groupId>
+ <artifactId>jackson-dataformat-xml</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.datatype</groupId>
+ <artifactId>jackson-datatype-jdk8</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.datatype</groupId>
+ <artifactId>jackson-datatype-jsr310</artifactId>
+ <version>${jackson2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.infradna.tool</groupId>
+ <artifactId>bridge-method-annotation</artifactId>
+ <version>1.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ <version>1.3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-codec</groupId>
+ <artifactId>commons-codec</artifactId>
+ <version>1.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-collections</groupId>
+ <artifactId>commons-collections</artifactId>
+ <version>3.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-configuration</groupId>
+ <artifactId>commons-configuration</artifactId>
+ <version>1.6</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-daemon</groupId>
+ <artifactId>commons-daemon</artifactId>
+ <version>1.0.3</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <version>2.4</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-lang</groupId>
+ <artifactId>commons-lang</artifactId>
+ <version>${commons-lang.version}</version>
+ </dependency>
+ <dependency>
+ <!-- This version is exported by jdisc via jcl-over-slf4j. -->
+ <groupId>commons-logging</groupId>
+ <artifactId>commons-logging</artifactId>
+ <version>1.1.1</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ <version>2.0</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-pool</groupId>
+ <artifactId>commons-pool</artifactId>
+ <version>1.5.6</version>
+ </dependency>
+ <!-- Explicitly included to get Zookeeper version 3.4.10,
+ can be excluded if you want the Zookeeper version
+ used by curator by default
+ -->
+ <dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ <version>3.4.10</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-recipes</artifactId>
+ <version>${curator.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.curator</groupId>
+ <artifactId>curator-test</artifactId>
+ <version>${curator.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>javax.servlet</groupId>
+ <artifactId>javax.servlet-api</artifactId>
+ <version>3.1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>4.12</version>
+ </dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr-runtime</artifactId>
+ <version>${antlr.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ <version>${antlr4.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.aries.spifly</groupId>
+ <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId>
+ <version>${aries.spifly.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-lang3</artifactId>
+ <version>3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.framework</artifactId>
+ <version>4.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.log</artifactId>
+ <version>1.0.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.felix</groupId>
+ <artifactId>org.apache.felix.main</artifactId>
+ <version>4.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>fluent-hc</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpcore</artifactId>
+ <version>4.3.3</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpmime</artifactId>
+ <version>4.3.6</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-artifact</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-core</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-model</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.plugin-tools</groupId>
+ <artifactId>maven-plugin-annotations</artifactId>
+ <version>3.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-plugin-api</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven</groupId>
+ <artifactId>maven-project</artifactId>
+ <version>2.2.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>3.0.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.surefire</groupId>
+ <artifactId>surefire-junit4</artifactId>
+ <version>${surefire.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.maven.surefire</groupId>
+ <artifactId>surefire-providers</artifactId>
+ <version>${surefire.version}</version>
+ <type>pom</type>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jettison</groupId>
+ <artifactId>jettison</artifactId>
+ <version>1.3.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.cthul</groupId>
+ <artifactId>cthul-matchers</artifactId>
+ <version>1.0</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-continuation</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-servlet</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-servlets</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-http</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-jmx</artifactId>
+ <version>${jetty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>uk.co.datumedge</groupId>
+ <artifactId>hamcrest-json</artifactId>
+ <version>0.2</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hdrhistogram</groupId>
+ <artifactId>HdrHistogram</artifactId>
+ <version>2.1.8</version>
+ </dependency>
+ <dependency>
+ <groupId>org.json</groupId>
+ <artifactId>json</artifactId>
+ <version>20090211</version>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <version>1.9.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <version>1.9.5</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.osgi</groupId>
+ <artifactId>org.osgi.compendium</artifactId>
+ <version>4.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.osgi</groupId>
+ <artifactId>org.osgi.core</artifactId>
+ <version>4.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scala-library</artifactId>
+ <version>${scala.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang.modules</groupId>
+ <artifactId>scala-parser-combinators_${scala.major-version}</artifactId>
+ <version>1.0.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang.modules</groupId>
+ <artifactId>scala-xml_${scala.major-version}</artifactId>
+ <version>1.0.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.major-version}</artifactId>
+ <version>2.2.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>jcl-over-slf4j</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>log4j-over-slf4j</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jdk14</artifactId>
+ <version>1.7.5</version>
+ </dependency>
+ <dependency>
+ <groupId>org.springframework</groupId>
+ <artifactId>spring-test</artifactId>
+ <version>4.0.6.RELEASE</version>
+ </dependency>
+ <dependency>
+ <groupId>org.testng</groupId>
+ <artifactId>testng</artifactId>
+ <version>6.10</version>
+ </dependency>
+ <dependency>
+ <groupId>org.twdata.maven</groupId>
+ <artifactId>mojo-executor</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.jcip</groupId>
+ <artifactId>jcip-annotations</artifactId>
+ <version>1.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.jpountz.lz4</groupId>
+ <artifactId>lz4</artifactId>
+ <version>1.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>net.spy</groupId>
+ <artifactId>spymemcached</artifactId>
+ <version>2.10.1</version>
+ </dependency>
+ <dependency>
+ <groupId>xerces</groupId>
+ <artifactId>xercesImpl</artifactId>
+ <version>2.11.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcpkix-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcprov-jdk15on</artifactId>
+ <version>${bouncycastle.version}</version>
+ </dependency>
+ <!-- jersey 2 support -->
+ <dependency>
+ <groupId>javax.ws.rs</groupId>
+ <artifactId>javax.ws.rs-api</artifactId>
+ <version>${javax.ws.rs-api.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.containers</groupId>
+ <artifactId>jersey-container-servlet-core</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.containers</groupId>
+ <artifactId>jersey-container-servlet</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.media</groupId>
+ <artifactId>jersey-media-json-jackson</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.media</groupId>
+ <artifactId>jersey-media-multipart</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.ext</groupId>
+ <artifactId>jersey-proxy-client</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.glassfish.jersey.core</groupId>
+ <artifactId>jersey-client</artifactId>
+ <version>${jersey2.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.ibm.icu</groupId>
+ <artifactId>icu4j</artifactId>
+ <version>57.1</version>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.athenz</groupId>
+ <artifactId>athenz-zms-java-client</artifactId>
+ <version>${athenz.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.athenz</groupId>
+ <artifactId>athenz-zts-java-client</artifactId>
+ <version>${athenz.version}</version>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+
+ <properties>
+ <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> <!-- must be kept in sync with version used by current jersey2.version -->
+ <antlr.version>3.5.2</antlr.version>
+ <antlr4.version>4.5</antlr4.version>
+ <aries.spifly.version>1.0.8</aries.spifly.version>
+ <aries.util.version>1.0.0</aries.util.version>
+ <asm-debug-all.version>5.0.3</asm-debug-all.version>
+ <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories -->
+ <athenz.version>1.7.28</athenz.version>
+ <bouncycastle.version>1.58</bouncycastle.version>
+ <commons-lang.version>2.6</commons-lang.version>
+ <!-- WARNING: If you change curator version, you also need to update
+ zkfacade/src/main/java/org/apache/curator/**/package-info.java
+ using something like
+ find zkfacade/src/main/java/org/apache/curator -name package-info.java | \
+ xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g'
+ -->
+ <curator.version>2.9.1</curator.version>
+ <jackson2.version>2.8.3</jackson2.version>
+ <jersey2.version>2.23.2</jersey2.version>
+ <jetty.version>9.4.6.v20170531</jetty.version>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
+ <test.hide>true</test.hide>
+ <doclint>all</doclint>
+ <scala.major-version>2.11</scala.major-version>
+ <scala.version>${scala.major-version}.4</scala.version>
+ <surefire.version>2.19.1</surefire.version> <!-- NOTE bjorncs 15.06.2017: Version 2.20 has OoM issues -->
+ </properties>
+
<modules>
<module>application</module>
<module>application-deploy-plugin</module>
diff --git a/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp b/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp
index 4a195514db1..4b73e4ca115 100644
--- a/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp
+++ b/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp
@@ -13,6 +13,7 @@
#include <vespa/searchcore/proton/persistenceengine/persistenceengine.h>
#include <vespa/vdslib/distribution/distribution.h>
#include <vespa/vdslib/state/clusterstate.h>
+#include <vespa/metrics/loadmetric.h>
#include <vespa/vespalib/testkit/testapp.h>
#include <algorithm>
#include <set>
@@ -369,7 +370,7 @@ Timestamp tstamp2(2);
Timestamp tstamp3(3);
DocumentSelection doc_sel("");
Selection selection(doc_sel);
-BucketSpace altBucketSpace(1);
+BucketSpace altBucketSpace(2);
void
diff --git a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp
index c7b01d209ee..e2b389fb898 100644
--- a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp
+++ b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp
@@ -3,8 +3,8 @@
#include "persistenceengine.h"
#include "ipersistenceengineowner.h"
#include "transport_latch.h"
+#include <vespa/metrics/loadmetric.h>
#include <vespa/vespalib/stllike/hash_set.h>
-#include <vespa/fastos/thread.h>
#include <vespa/log/log.h>
LOG_SETUP(".proton.persistenceengine.persistenceengine");
@@ -23,6 +23,8 @@ using vespalib::IllegalStateException;
using vespalib::Sequence;
using vespalib::make_string;
+using namespace std::chrono_literals;
+
namespace proton {
namespace {
@@ -623,7 +625,7 @@ PersistenceEngine::destroyIterators()
Result res(destroyIterator(id, context));
if (res.hasError()) {
LOG(debug, "%ld iterator left. Can not destroy iterator '%ld'. Reason='%s'", _iterators.size(), id.getValue(), res.toString().c_str());
- FastOS_Thread::Sleep(100); // Sleep 0.1 seconds
+ std::this_thread::sleep_for(100ms);
}
}
}
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 5f6717d9516..09ccf9928b7 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -36,6 +36,21 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>3.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>1.4.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ <version>1.4.0</version>
+ </dependency>
+ <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>test</scope>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 785ed78492e..0eeb0a9e630 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
*
- * @param name The name of the variable whose value to return.
- * @return The value of the named variable.
+ * @param name the name of the variable whose value to return.
+ * @return the value of the named variable.
*/
public abstract Value get(String name);
+ /** Returns a variable as a tensor */
+ @Override
+ public Tensor getTensor(String name) { return get(name).asTensor(); }
+
/**
* <p>Returns the value of a <i>structured variable</i> on the form
* <code>name(argument*)(.output)?</code>, where <i>argument</i> is any
* string. This may be used to implement more advanced variables whose
* values are calculated at runtime from arguments. Supporting this in a
- * context is optional.
- *
+ * context is optional.
+ *
* <p>This default implementation generates a name on the form
* <code>name(argument1, argument2, ...argumentN).output</code>.
* If there are no arguments the parenthesis are omitted.
* If there is no output, the dot is omitted.</p>
*
- * @param name The name of this variable.
- * @param arguments The parsed arguments as given in the textual expression.
- * @param output The name of the value to output (to enable one named
+ * @param name the name of this variable.
+ * @param arguments the parsed arguments as given in the textual expression.
+ * @param output the name of the value to output (to enable one named
* calculation to output several), or null to output the
* "main" (or only) value.
*/
@@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext {
* context subclasses. This default implementation throws
* UnsupportedOperationException.</p>
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public Value get(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
}
/**
- * <p>Lookup by index rather than name directly to a double. This is supported by some optimized
+ * Lookup by index rather than name directly to a double. This is supported by some optimized
* context subclasses. This default implementation throws
- * UnsupportedOperationException.</p>
+ * UnsupportedOperationException.
*
- * @param index The index of the variable whose value to return.
- * @return The value of the indexed variable.
+ * @param index the index of the variable whose value to return.
+ * @return the value of the indexed variable.
*/
public double getDouble(int index) {
throw new UnsupportedOperationException(this + " does not support variable lookup by index");
@@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext {
}
/**
- * <p>Sets a value to this, or throws an UnsupportedOperationException if
- * this is not supported. This default implementation does the latter.</p> *
+ * Sets a value to this, or throws an UnsupportedOperationException if
+ * this is not supported. This default implementation does the latter.
*
- * @param name The name of the variable to set.
+ * @param name the name of the variable to set.
* @param value the value to set. Ownership of this value is transferred to this - if it is mutable
* (not frozen) it may be modified during execution
- * @since 5.1.5
*/
public void put(String name, Value value) {
throw new UnsupportedOperationException(this + " does not support variable assignment");
}
/**
- * <p>Returns all the names available in this, or throws an
+ * Returns all the names available in this, or throws an
* UnsupportedOperationException if this operation is not supported. This
- * default implementation does the latter.</p>
+ * default implementation does the latter.
*
- * @return The set of all variable names.
+ * @return the set of all variable names.
*/
public Set<String> names() {
throw new UnsupportedOperationException(this + " does not support return a list of its names");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index ea750295423..2ef4a2ede2f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
@@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value {
public boolean hasDouble() { return true; }
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index ac8aba6a617..dad69b31181 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A string value.
*
* @author bratseth
- * @since 5.1.21
*/
public class StringValue extends Value {
@@ -35,6 +37,11 @@ public class StringValue extends Value {
}
@Override
+ public Tensor asTensor() {
+ return doubleAsTensor(asDouble());
+ }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 49c3ccb7b01..26c30fe5ed2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -2,14 +2,10 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.Optional;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
/**
* A Value containing a tensor.
@@ -23,7 +19,7 @@ public class TensorValue extends Value {
/** The tensor value of this */
private final Tensor value;
-
+
public TensorValue(Tensor value) {
this.value = value;
}
@@ -131,7 +127,7 @@ public class TensorValue extends Value {
public Value compare(TruthOperator operator, Value argument) {
return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString())));
}
-
+
private Tensor compareTensor(TruthOperator operator, Tensor argument) {
switch (operator) {
case LARGER: return value.larger(argument);
@@ -152,7 +148,7 @@ public class TensorValue extends Value {
else
return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble())));
}
-
+
private Tensor functionOnTensor(Function function, Tensor argument) {
switch (function) {
case min: return value.min(argument);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index b2ccbe572d0..40d70e0022c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* The result of a ranking expression evaluation.
@@ -25,6 +27,14 @@ public abstract class Value {
return new DoubleValue(asDouble());
}
+ /** Returns this as a tensor value */
+ public abstract Tensor asTensor();
+
+ /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ protected Tensor doubleAsTensor(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
+ }
+
/** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */
public abstract boolean hasDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
new file mode 100644
index 00000000000..947e6d7a5e1
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -0,0 +1,102 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * The result of importing a TensorFlow model into Vespa.
+ * - A set of signatures which are named collections of inputs and outputs.
+ * - A set of named constant tensors represented by Variable nodes in TensorFlow.
+ * - A list of warning messages.
+ *
+ * @author bratseth
+ */
+// This object can be built incrementally within this package, but is immutable when observed from outside the package
+public class ImportResult {
+
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final List<String> warnings = new ArrayList<>();
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void constant(String name, Tensor constant) { constants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void warn(String warning) { warnings.add(warning); }
+
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, n -> new Signature(n));
+ }
+
+ /** Returns an immutable map of the arguments ("Placeholders") of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+ /** Returns an immutable map of the constants of this */
+ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+
+ /**
+ * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
+ * which are not Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder are added.
+ */
+ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+
+ /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
+ public List<String> warnings() {
+ return warnings.stream().sorted().collect(Collectors.toList());
+ }
+
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+
+ /**
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
+ * and outputs maps to expressions nodes.
+ */
+ public class Signature {
+
+ private final String name;
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, String> outputs = new HashMap<>();
+
+ Signature(String name) {
+ this.name = name;
+ }
+
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+
+ /** Returns the result this is part of */
+ ImportResult owner() { return ImportResult.this; }
+
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to argument (Placeholder) name in the owner of this
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expression names of this */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
+
+ @Override
+ public String toString() { return "signature '" + name + "'"; }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
new file mode 100644
index 00000000000..bac141644c6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -0,0 +1,160 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
+ *
+ * @author bratseth
+ */
+class OperationMapper {
+
+ /*
+ A note on conversion from implicitly numbered to explicitly named dimensions:
+ Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+ 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+ comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+ around dimension renaming operations which mirrors those built into the TF operation definitions.
+
+ To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+ dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+ and the result is then renamed again (if necessary) to recover this convention across a full nested
+ computation.
+
+ This requires us to track tensor types throughout the conversion.
+ */
+
+ private TensorConverter tensorConverter = new TensorConverter();
+
+ TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+ ensureArguments(2, arguments, "join");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < b.type().rank())
+ throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " +
+ "but this is not supported when the second argument has a higher rank");
+
+ TensorFunction bFunction = b.function();
+
+ if (a.type().rank() > b.type().rank()) {
+ // Well now we have entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensor do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+ List<String> renameFrom = new ArrayList<>();
+ List<String> renameTo = new ArrayList<>();
+ int sizeDifference = a.type().rank() - b.type().rank();
+ for (int i = 0; i < b.type().rank(); i++) {
+ renameFrom.add(b.type().dimensions().get(i).name());
+ renameTo.add("d" + (sizeDifference + i));
+ }
+ bFunction = new Rename(bFunction, renameFrom, renameTo);
+ }
+
+ Join function = new Join(a.function(), bFunction, doubleFunction);
+ return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank
+ }
+
+ TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+ ensureArguments(1, arguments, "apply");
+ TypedTensorFunction a = arguments.get(0);
+
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ String name = tfNode.getName();
+ TensorType type = result.arguments().get(name);
+ if (type == null)
+ throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
+ "', but there is no such placeholder");
+ // Included literally in the expression and so must be produced by a separate macro in the rank profile
+ return new TypedTensorFunction(type, new VariableTensor(name));
+ }
+
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ if ( ! tfNode.getName().endsWith("/read"))
+ throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
+ "nodes are only supported when reading variables");
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+
+ String name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
+ Session.Runner fetched = model.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if ( importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
+ importedTensors.size());
+ Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
+ result.constant(name, constant);
+ return new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
+ }
+
+ TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
+ ensureArguments(2, arguments, "matmul");
+ TypedTensorFunction a = arguments.get(0);
+ TypedTensorFunction b = arguments.get(1);
+ if (a.type().rank() < 2 || b.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (a.type().rank() != b.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ String afterLastDim = "d" + (a.type().rank() + 1);
+ // Let the first dimension of the second tensor be the same as the second dimension of the first
+ // and the second dimension of the second argument be not present in the first argument, while leaving the
+ // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
+
+ // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
+
+ Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
+ ImmutableList.of("d1", afterLastDim));
+ Matmul matmul = new Matmul(a.function(), renamedB, "d1");
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
+ new Rename(matmul, afterLastDim, "d1"));
+ }
+
+ TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
+ ensureArguments(1, arguments, "softmax");
+ TypedTensorFunction a = arguments.get(0);
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
+ Softmax softmax = new Softmax(a.function(), dimension);
+ return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ }
+
+ private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
+ if ( arguments.size() != count)
+ throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
+ ", but got " + arguments.size());
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
new file mode 100644
index 00000000000..1960cf94591
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -0,0 +1,94 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+
+/**
+ * @author bratseth
+ */
+public class TensorConverter {
+
+ public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ TensorType type = toVespaTensorType(tfTensor.shape());
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ private TensorType toVespaTensorType(long[] shape) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed("d" + (dimensionIndex++), dimensionSize);
+ }
+ return b.build();
+ }
+
+ private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ // TODO: The rest
+ default:
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+
+ private final int size;
+
+ protected Values(int size) {
+ this.size = size;
+ }
+
+ abstract double get(int i);
+
+ int size() { return size; }
+
+ }
+
+ private static class DoubleValues extends Values {
+
+ private final DoubleBuffer values;
+
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+ private static class FloatValues extends Values {
+
+ private final FloatBuffer values;
+
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+
+ @Override
+ double get(int i) {
+ return values.get(i);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
new file mode 100644
index 00000000000..4a6551adca7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -0,0 +1,145 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ */
+public class TensorFlowImporter {
+
+ private final OperationMapper operationMapper = new OperationMapper();
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a pbtxt file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportResult importModel(String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importModel(model);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ /** Imports a TensorFlow model */
+ public ImportResult importModel(SavedModelBundle model) {
+ try {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
+ }
+ }
+
+ private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ ImportResult result = new ImportResult();
+ for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
+ ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+
+ importInputs(signatureEntry.getValue().getInputsMap(), signature);
+ for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ String outputName = output.getKey();
+ try {
+ NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef());
+ importNode(node, graph.getGraphDef(), model, result);
+ signature.output(outputName, nameOf(output.getValue().getName()));
+ }
+ catch (IllegalArgumentException e) {
+ result.warn("Skipping output '" + outputName + "' of " + signature +
+ ": " + Exceptions.toMessageString(e));
+ }
+ }
+ }
+ return result;
+ }
+
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ inputInfoMap.forEach((key, value) -> {
+ String argumentName = nameOf(value.getName());
+ TensorType argumentType = importTensorType(value.getTensorShape());
+ // Arguments are (Placeholder) nodes, so not local to the signature:
+ signature.owner().argument(argumentName, argumentType);
+ signature.input(key, argumentName);
+ });
+ }
+
+ private TensorType importTensorType(TensorShapeProto tensorShape) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
+ int dimensionSize = (int)dimension.getSize();
+ if (dimensionSize >= 0)
+ b.indexed("d" + b.rank(), dimensionSize);
+ else
+ b.indexed("d" + b.rank()); // unbound size
+ }
+ return b.build();
+ }
+
+ /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
+ // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
+ // will be used
+ result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
+ return function;
+ }
+
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
+ // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
+ switch (tfNode.getOp().toLowerCase()) {
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, result);
+ case "identity" : return operationMapper.identity(tfNode, model, result);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
+ default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
+ }
+ }
+
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult result) {
+ return tfNode.getInputList().stream()
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
+ .collect(Collectors.toList());
+ }
+
+ private NodeDef getNode(String name, GraphDef graph) {
+ return graph.getNodeList().stream()
+ .filter(node -> node.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private String nameOf(String name) {
+ return name.split(":")[0];
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
new file mode 100644
index 00000000000..5712da77700
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+/**
+ * A tensor function returning a specific tensor type
+ *
+ * @author bratseth
+ */
+final class TypedTensorFunction {
+
+ private final TensorType type;
+ private final TensorFunction function;
+
+ public TypedTensorFunction(TensorType type, TensorFunction function) {
+ this.type = type;
+ this.function = function;
+ }
+
+ public TensorType type() { return type; }
+ public TensorFunction function() { return function; }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 71699b379b2..9da1ba40144 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
-import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -10,27 +9,26 @@ import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
-import java.util.function.*;
/**
* A tensor generating function, whose arguments are determined by a tensor type
- *
+ *
* @author bratseth
*/
public class GeneratorLambdaFunctionNode extends CompositeNode {
private final TensorType type;
private final ExpressionNode generator;
-
+
public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
- throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
+ throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
"dimensions, but tried to generate " + type);
// TODO: Verify that the function only accesses the given arguments
this.type = type;
this.generator = generator;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(generator);
@@ -53,24 +51,24 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
public Value evaluate(Context context) {
return generator.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as an operator which converts a list of integers into a double
*/
- public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
- return new IntegerListToDoubleLambda();
+ public LongListToDoubleLambda asLongListToDoubleOperator() {
+ return new LongListToDoubleLambda();
}
- private class IntegerListToDoubleLambda implements java.util.function.Function<List<Integer>, Double> {
+ private class LongListToDoubleLambda implements java.util.function.Function<List<Long>, Double> {
@Override
- public Double apply(List<Integer> arguments) {
+ public Double apply(List<Long> arguments) {
MapContext context = new MapContext();
for (int i = 0; i < type.dimensions().size(); i++)
context.put(type.dimensions().get(i).name(), arguments.get(i));
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return GeneratorLambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 1f8db6e036c..ba765d07094 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -17,7 +17,7 @@ import java.util.Map;
* @author bratseth
*/
public class SerializationContext {
-
+
/** Expression functions indexed by name */
private final ImmutableMap<String, ExpressionFunction> functions;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ce21e132980..8af3448ca6f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -21,22 +21,32 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorFunctionNode extends CompositeNode {
private final TensorFunction function;
-
+
public TensorFunctionNode(TensorFunction function) {
this.function = function;
}
+ /** Returns the tensor function wrapped by this */
+ public TensorFunction function() { return function; }
+
@Override
public List<ExpressionNode> children() {
return function.functionArguments().stream()
- .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .map(this::toExpressionNode)
.collect(Collectors.toList());
}
+ private ExpressionNode toExpressionNode(TensorFunction f) {
+ if (f instanceof TensorFunctionExpressionNode)
+ return ((TensorFunctionExpressionNode)f).expression;
+ else
+ return new TensorFunctionNode(f);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
@@ -50,7 +60,7 @@ public class TensorFunctionNode extends CompositeNode {
// Serialize as primitive
return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
}
-
+
@Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
@@ -59,8 +69,8 @@ public class TensorFunctionNode extends CompositeNode {
public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
-
- /**
+
+ /**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
@@ -68,13 +78,13 @@ public class TensorFunctionNode extends CompositeNode {
/** An expression which produces a tensor */
private final ExpressionNode expression;
-
+
public TensorFunctionExpressionNode(ExpressionNode expression) {
this.expression = expression;
}
-
+
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> functionArguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -108,7 +118,7 @@ public class TensorFunctionNode extends CompositeNode {
public String toString() {
return toString(ExpressionNodeToStringContext.empty);
}
-
+
@Override
public String toString(ToStringContext c) {
if (c instanceof ExpressionNodeToStringContext) {
@@ -121,14 +131,14 @@ public class TensorFunctionNode extends CompositeNode {
}
}
-
+
/** Allows passing serialization context arguments through TensorFunctions */
private static class ExpressionNodeToStringContext implements ToStringContext {
-
+
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
-
+
public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 7821ab88b86..541738db8e0 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -467,7 +467,7 @@ ExpressionNode tensorGenerate() :
}
{
<TENSOR> type = tensorTypeArgument() <LBRACE> generator = expression() <RBRACE>
- { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); }
+ { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); }
}
ExpressionNode tensorRange() :
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
new file mode 100644
index 00000000000..a1861a1c981
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A very simple MNIST classifier.
+
+See extensive documentation at
+https://www.tensorflow.org/get_started/mnist/beginners
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+import tensorflow as tf
+
+FLAGS = None
+
+
+def main(_):
+ # Import data
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+
+ # Create the model
+ x = tf.placeholder(tf.float32, [None, 784])
+ W = tf.Variable(tf.zeros([784, 10]))
+ b = tf.Variable(tf.zeros([10]))
+ y = tf.matmul(x, W) + b
+
+ # Define loss and optimizer
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ # The raw formulation of cross-entropy,
+ #
+ # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
+ # reduction_indices=[1]))
+ #
+ # can be numerically unstable.
+ #
+ # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+ # outputs of 'y', and then average across the batch.
+ cross_entropy = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+ # Train
+ for _ in range(1000):
+ batch_xs, batch_ys = mnist.train.next_batch(100)
+ sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
+
+ # Test trained model
+ correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print(sess.run(accuracy, feed_dict={x: mnist.test.images,
+ y_: mnist.test.labels}))
+
+ # Save the model
+ export_path = "saved"
+ print('Exporting trained model to ', export_path)
+ builder = tf.saved_model.builder.SavedModelBuilder(export_path)
+ signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y})
+ builder.add_meta_graph_and_variables(sess,
+ [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={'serving_default':signature})
+ builder.save(as_text=True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
+ help='Directory for storing input data')
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
new file mode 100644
index 00000000000..8100dfd594d
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
@@ -0,0 +1,5039 @@
+saved_model_schema_version: 1
+meta_graphs {
+ meta_info_def {
+ stripped_op_list {
+ op {
+ name: "Add"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+ }
+ op {
+ name: "ApplyGradientDescent"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "delta"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ }
+ op {
+ name: "ArgMax"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dimension"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "output_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "output_type"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Assign"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "validate_shape"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ allows_uninitialized_input: true
+ }
+ op {
+ name: "BroadcastGradientArgs"
+ input_arg {
+ name: "s0"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "s1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r0"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "r1"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Cast"
+ input_arg {
+ name: "x"
+ type_attr: "SrcT"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "DstT"
+ }
+ attr {
+ name: "SrcT"
+ type: "type"
+ }
+ attr {
+ name: "DstT"
+ type: "type"
+ }
+ }
+ op {
+ name: "ConcatV2"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Const"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ }
+ op {
+ name: "Equal"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_QUINT8
+ type: DT_QINT8
+ type: DT_QINT32
+ type: DT_STRING
+ type: DT_BOOL
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "ExpandDims"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dim"
+ type_attr: "Tdim"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tdim"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Fill"
+ input_arg {
+ name: "dims"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "FloorDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Identity"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ op {
+ name: "MatMul"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ type_attr: "T"
+ }
+ attr {
+ name: "transpose_a"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "transpose_b"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "Mean"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "MergeV2Checkpoints"
+ input_arg {
+ name: "checkpoint_prefixes"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "destination_prefix"
+ type: DT_STRING
+ }
+ attr {
+ name: "delete_old_dirs"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Mul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "NoOp"
+ }
+ op {
+ name: "Pack"
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "axis"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ }
+ op {
+ name: "Placeholder"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ op {
+ name: "Prod"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RealDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Reshape"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "shape"
+ type_attr: "Tshape"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tshape"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "RestoreV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "SaveV2"
+ input_arg {
+ name: "prefix"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensor_names"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shape_and_slices"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "tensors"
+ type_list_attr: "dtypes"
+ }
+ attr {
+ name: "dtypes"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+ }
+ op {
+ name: "Shape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "ShardedFilename"
+ input_arg {
+ name: "basename"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "shard"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "num_shards"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "filename"
+ type: DT_STRING
+ }
+ }
+ op {
+ name: "Slice"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "begin"
+ type_attr: "Index"
+ }
+ input_arg {
+ name: "size"
+ type_attr: "Index"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Index"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "SoftmaxCrossEntropyWithLogits"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "labels"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "loss"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprop"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ }
+ op {
+ name: "StringJoin"
+ input_arg {
+ name: "inputs"
+ type: DT_STRING
+ number_attr: "N"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "separator"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ }
+ op {
+ name: "Sub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sum"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "Tile"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "multiples"
+ type_attr: "Tmultiples"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tmultiples"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "VariableV2"
+ output_arg {
+ name: "ref"
+ type_attr: "dtype"
+ is_ref: true
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+ }
+ op {
+ name: "ZerosLike"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ }
+ }
+ tags: "serve"
+ tensorflow_version: "1.4.1"
+ tensorflow_git_version: "v1.4.0-19-ga52c8d9b01"
+ }
+ graph_def {
+ node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "zeros"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable/read"
+ op: "Identity"
+ input: "Variable"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "zeros_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 10
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Variable_1"
+ op: "VariableV2"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "Variable_1/Assign"
+ op: "Assign"
+ input: "Variable_1"
+ input: "zeros_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Variable_1/read"
+ op: "Identity"
+ input: "Variable_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "add"
+ op: "Add"
+ input: "MatMul"
+ input: "Variable_1/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Placeholder_1"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Rank_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_1"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub"
+ op: "Sub"
+ input: "Rank_1"
+ input: "Sub/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice/begin"
+ op: "Pack"
+ input: "Sub"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice"
+ op: "Slice"
+ input: "Shape_1"
+ input: "Slice/begin"
+ input: "Slice/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "concat/values_0"
+ input: "Slice"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape"
+ op: "Reshape"
+ input: "add"
+ input: "concat"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Rank_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+ }
+ node {
+ name: "Shape_2"
+ op: "Shape"
+ input: "Placeholder_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "Sub_1/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_1"
+ op: "Sub"
+ input: "Rank_2"
+ input: "Sub_1/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1/begin"
+ op: "Pack"
+ input: "Sub_1"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_1/size"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_1"
+ op: "Slice"
+ input: "Shape_2"
+ input: "Slice_1/begin"
+ input: "Slice_1/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/values_0"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "concat_1/values_0"
+ input: "Slice_1"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_1"
+ op: "Reshape"
+ input: "Placeholder_1"
+ input: "concat_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "SoftmaxCrossEntropyWithLogits"
+ op: "SoftmaxCrossEntropyWithLogits"
+ input: "Reshape"
+ input: "Reshape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "Sub_2"
+ op: "Sub"
+ input: "Rank"
+ input: "Sub_2/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/begin"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Slice_2/size"
+ op: "Pack"
+ input: "Sub_2"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "Slice_2"
+ op: "Slice"
+ input: "Shape"
+ input: "Slice_2/begin"
+ input: "Slice_2/size"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Reshape_2"
+ op: "Reshape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ input: "Slice_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean"
+ op: "Mean"
+ input: "Reshape_2"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Fill"
+ op: "Fill"
+ input: "gradients/Shape"
+ input: "gradients/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape/shape"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Fill"
+ input: "gradients/Mean_grad/Reshape/shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Tile"
+ op: "Tile"
+ input: "gradients/Mean_grad/Reshape"
+ input: "gradients/Mean_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tmultiples"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_1"
+ op: "Shape"
+ input: "Reshape_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Shape_2"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_1"
+ input: "gradients/Mean_grad/Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Const_1"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Prod_1"
+ op: "Prod"
+ input: "gradients/Mean_grad/Shape_2"
+ input: "gradients/Mean_grad/Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum/y"
+ op: "Const"
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Maximum"
+ op: "Maximum"
+ input: "gradients/Mean_grad/Prod_1"
+ input: "gradients/Mean_grad/Maximum/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/floordiv"
+ op: "FloorDiv"
+ input: "gradients/Mean_grad/Prod"
+ input: "gradients/Mean_grad/Maximum"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/Mean_grad/Shape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/Cast"
+ op: "Cast"
+ input: "gradients/Mean_grad/floordiv"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Mean_grad/truediv"
+ op: "RealDiv"
+ input: "gradients/Mean_grad/Tile"
+ input: "gradients/Mean_grad/Cast"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Shape"
+ op: "Shape"
+ input: "SoftmaxCrossEntropyWithLogits"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_2_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/Mean_grad/truediv"
+ input: "gradients/Reshape_2_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/zeros_like"
+ op: "ZerosLike"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: -1
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ op: "ExpandDims"
+ input: "gradients/Reshape_2_grad/Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ op: "Mul"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims"
+ input: "SoftmaxCrossEntropyWithLogits:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Shape"
+ op: "Shape"
+ input: "add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/Reshape_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul"
+ input: "gradients/Reshape_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape"
+ op: "Shape"
+ input: "MatMul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Shape_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 10
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/BroadcastGradientArgs"
+ op: "BroadcastGradientArgs"
+ input: "gradients/add_grad/Shape"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum"
+ input: "gradients/add_grad/Shape"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Sum_1"
+ op: "Sum"
+ input: "gradients/Reshape_grad/Reshape"
+ input: "gradients/add_grad/BroadcastGradientArgs:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/Reshape_1"
+ op: "Reshape"
+ input: "gradients/add_grad/Sum_1"
+ input: "gradients/add_grad/Shape_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/Reshape_1"
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/add_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/add_grad/Reshape_1"
+ input: "^gradients/add_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/add_grad/Reshape_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul"
+ op: "MatMul"
+ input: "gradients/add_grad/tuple/control_dependency"
+ input: "Variable/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/MatMul_1"
+ op: "MatMul"
+ input: "Placeholder"
+ input: "gradients/add_grad/tuple/control_dependency"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/group_deps"
+ op: "NoOp"
+ input: "^gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/MatMul_1"
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "gradients/MatMul_grad/tuple/control_dependency_1"
+ op: "Identity"
+ input: "gradients/MatMul_grad/MatMul_1"
+ input: "^gradients/MatMul_grad/tuple/group_deps"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@gradients/MatMul_grad/MatMul_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/learning_rate"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.5
+ }
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/MatMul_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent/update_Variable_1/ApplyGradientDescent"
+ op: "ApplyGradientDescent"
+ input: "Variable_1"
+ input: "GradientDescent/learning_rate"
+ input: "gradients/add_grad/tuple/control_dependency_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "GradientDescent"
+ op: "NoOp"
+ input: "^GradientDescent/update_Variable/ApplyGradientDescent"
+ input: "^GradientDescent/update_Variable_1/ApplyGradientDescent"
+ }
+ node {
+ name: "init"
+ op: "NoOp"
+ input: "^Variable/Assign"
+ input: "^Variable_1/Assign"
+ }
+ node {
+ name: "ArgMax/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax"
+ op: "ArgMax"
+ input: "add"
+ input: "ArgMax/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1/dimension"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "ArgMax_1"
+ op: "ArgMax"
+ input: "Placeholder_1"
+ input: "ArgMax_1/dimension"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node {
+ name: "Equal"
+ op: "Equal"
+ input: "ArgMax"
+ input: "ArgMax_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Cast_1"
+ op: "Cast"
+ input: "Equal"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "Mean_1"
+ op: "Mean"
+ input: "Cast_1"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "save/Const"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "model"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin/inputs_1"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/StringJoin"
+ op: "StringJoin"
+ input: "save/Const"
+ input: "save/StringJoin/inputs_1"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "separator"
+ value {
+ s: ""
+ }
+ }
+ }
+ node {
+ name: "save/num_shards"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename/shard"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node {
+ name: "save/ShardedFilename"
+ op: "ShardedFilename"
+ input: "save/StringJoin"
+ input: "save/ShardedFilename/shard"
+ input: "save/num_shards"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: "Variable"
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: ""
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/SaveV2"
+ op: "SaveV2"
+ input: "save/ShardedFilename"
+ input: "save/SaveV2/tensor_names"
+ input: "save/SaveV2/shape_and_slices"
+ input: "Variable"
+ input: "Variable_1"
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/control_dependency"
+ op: "Identity"
+ input: "save/ShardedFilename"
+ input: "^save/SaveV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@save/ShardedFilename"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ op: "Pack"
+ input: "save/ShardedFilename"
+ input: "^save/control_dependency"
+ attr {
+ key: "N"
+ value {
+ i: 1
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "axis"
+ value {
+ i: 0
+ }
+ }
+ }
+ node {
+ name: "save/MergeV2Checkpoints"
+ op: "MergeV2Checkpoints"
+ input: "save/MergeV2Checkpoints/checkpoint_prefixes"
+ input: "save/Const"
+ attr {
+ key: "delete_old_dirs"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/Identity"
+ op: "Identity"
+ input: "save/Const"
+ input: "^save/control_dependency"
+ input: "^save/MergeV2Checkpoints"
+ attr {
+ key: "T"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2/tensor_names"
+ input: "save/RestoreV2/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign"
+ op: "Assign"
+ input: "Variable"
+ input: "save/RestoreV2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/tensor_names"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: "Variable_1"
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1/shape_and_slices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node {
+ name: "save/RestoreV2_1"
+ op: "RestoreV2"
+ input: "save/Const"
+ input: "save/RestoreV2_1/tensor_names"
+ input: "save/RestoreV2_1/shape_and_slices"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtypes"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node {
+ name: "save/Assign_1"
+ op: "Assign"
+ input: "Variable_1"
+ input: "save/RestoreV2_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@Variable_1"
+ }
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "save/restore_shard"
+ op: "NoOp"
+ input: "^save/Assign"
+ input: "^save/Assign_1"
+ }
+ node {
+ name: "save/restore_all"
+ op: "NoOp"
+ input: "^save/restore_shard"
+ }
+ versions {
+ producer: 24
+ }
+ }
+ saver_def {
+ filename_tensor_name: "save/Const:0"
+ save_tensor_name: "save/Identity:0"
+ restore_op_name: "save/restore_all"
+ max_to_keep: 5
+ sharded: true
+ keep_checkpoint_every_n_hours: 10000.0
+ version: V2
+ }
+ collection_def {
+ key: "train_op"
+ value {
+ node_list {
+ value: "GradientDescent"
+ }
+ }
+ }
+ collection_def {
+ key: "trainable_variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ collection_def {
+ key: "variables"
+ value {
+ bytes_list {
+ value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0"
+ value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0"
+ }
+ }
+ }
+ signature_def {
+ key: "serving_default"
+ value {
+ inputs {
+ key: "x"
+ value {
+ name: "Placeholder:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ outputs {
+ key: "y"
+ value {
+ name: "add:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ method_name: "tensorflow/serving/predict"
+ }
+ }
+}
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
new file mode 100644
index 00000000000..8474aa0a04c
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
new file mode 100644
index 00000000000..cfcdac20409
--- /dev/null
+++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 82e5d0cfe5b..3aa2d144f1f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.*;
-import com.yahoo.tensor.Tensor;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import org.junit.Test;
+
import static org.junit.Assert.assertEquals;
/**
@@ -83,7 +88,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0, "sin(0)");
tester.assertEvaluates(1, "cos(0)");
tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))");
-
+
// Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero)
tester.assertEvaluates(0, "random(1)");
tester.assertEvaluates(0, "random(foo)");
@@ -152,7 +157,7 @@ public class EvaluationTestCase {
"tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
-
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ee2b1c147e3..ba0db4de5e1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -34,7 +34,7 @@ public class EvaluationTester {
}
// TODO: Test both bound and unbound indexed
- public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
+ public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
String ... tensorArgumentStrings) {
MapContext context = defaultContext.thawedCopy();
int argumentIndex = 0;
@@ -46,7 +46,7 @@ public class EvaluationTester {
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
context.put("tensor" + (argumentIndex++), new TensorValue(argument));
}
- return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
+ return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
mappedTensors ? "Mapped tensors" : "Indexed tensors");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
new file mode 100644
index 00000000000..0370fc7fc94
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -0,0 +1,118 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Mnist_SoftmaxTestCase {
+
+ @Ignore
+ @Test
+ public void testImporting() {
+ String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ ImportResult result = new TensorFlowImporter().importModel(model);
+
+ // Check logged messages
+ result.warnings().forEach(System.err::println);
+ assertEquals(0, result.warnings().size());
+
+ // Check constants
+ assertEquals(2, result.constants().size());
+
+ Tensor constant0 = result.constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = result.constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check signatures
+ assertEquals(1, result.signatures().size());
+ ImportResult.Signature signature = result.signatures().get("serving_default");
+ assertNotNull(signature);
+
+ // ... signature inputs
+ assertEquals(1, signature.inputs().size());
+ TensorType argument0 = signature.inputArgument("x");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // ... signature outputs
+ assertEquals(1, signature.outputs().size());
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("add", output.getName());
+ assertEquals("" +
+ "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
+ "rename(constant(Variable_1), d0, d1), " +
+ "f(a,b)(a + b))",
+ toNonPrimitiveString(output));
+
+ // Test execution
+ assertEqualResult(model, result, "Variable/read");
+ assertEqualResult(model, result, "Variable_1/read");
+ assertEqualResult(model, result, "MatMul");
+ assertEqualResult(model, result, "add");
+ }
+
+ private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) {
+ Tensor tfResult = tensorFlowExecute(model, operationName);
+ Context context = contextFrom(result);
+ Tensor placeholder = placeholderArgument();
+ context.put("Placeholder", new TensorValue(placeholder));
+ Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ }
+
+ private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) {
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
+ runner.feed("Placeholder", placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(ImportResult result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private String toNonPrimitiveString(RankingExpression expression) {
+ // toString on the wrapping expression will map to primitives, which is harder to read
+ return ((TensorFunctionNode)expression.getRoot()).function().toString();
+ }
+
+ private Tensor placeholderArgument() {
+ int size = 784;
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
+ for (int i = 0; i < size; i++)
+ b.cell(0, 0, i);
+ return b.build();
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index dde9d4bf21e..1960c1fe876 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -59,7 +59,7 @@ public class TensorConformanceTest {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
-
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
return true;
@@ -67,7 +67,7 @@ public class TensorConformanceTest {
if (!node.has("expression")) {
return true; // ignore
}
-
+
String expression = node.get("expression").asText();
MapContext context = getInput(node.get("inputs"));
Tensor expect = getTensor(node.get("result").get("expect").asText());
diff --git a/searchlib/src/vespa/searchlib/attribute/postingchange.cpp b/searchlib/src/vespa/searchlib/attribute/postingchange.cpp
index 9957d162d9d..702ff0fc5cf 100644
--- a/searchlib/src/vespa/searchlib/attribute/postingchange.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/postingchange.cpp
@@ -6,6 +6,7 @@
#include "postinglistattribute.h"
#include <vespa/searchlib/common/growablebitvector.h>
#include <vespa/vespalib/util/array.hpp>
+#include <vespa/vespalib/stllike/hash_map.hpp>
namespace search {
diff --git a/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp b/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp
index 990117a71ce..1ab3c6b8b51 100644
--- a/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp
+++ b/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp
@@ -2,6 +2,7 @@
#include "foregroundtaskexecutor.h"
#include <vespa/vespalib/util/threadstackexecutor.h>
+#include <vespa/vespalib/stllike/hash_map.hpp>
using vespalib::ThreadStackExecutor;
diff --git a/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp b/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp
index 45004db2615..446c9ec39ec 100644
--- a/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp
+++ b/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "sequencedtaskexecutor.h"
+#include <vespa/vespalib/stllike/hash_map.hpp>
using vespalib::BlockingThreadStackExecutor;
diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
index 8f73c9862ae..7e881d8de76 100644
--- a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
+++ b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
@@ -209,8 +209,7 @@ VisitCache::Cache::locateAndInvalidateOtherSubsets(const LockGuard & cacheGuard,
CompressedBlobSet
VisitCache::read(const IDocumentStore::LidVector & lids) const {
- KeySet key(lids);
- return _cache->readSet(lids);
+ return _cache->readSet(KeySet(lids));
}
void
diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.h b/searchlib/src/vespa/searchlib/docstore/visitcache.h
index 1bf867c5580..effc6c19a21 100644
--- a/searchlib/src/vespa/searchlib/docstore/visitcache.h
+++ b/searchlib/src/vespa/searchlib/docstore/visitcache.h
@@ -20,7 +20,7 @@ class KeySet {
public:
KeySet() : _keys() { }
KeySet(uint32_t key);
- KeySet(const IDocumentStore::LidVector &keys);
+ explicit KeySet(const IDocumentStore::LidVector &keys);
uint32_t hash() const { return _keys.empty() ? 0 : _keys[0]; }
bool operator==(const KeySet &rhs) const { return _keys == rhs._keys; }
bool operator<(const KeySet &rhs) const { return _keys < rhs._keys; }
diff --git a/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp b/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp
index fe57de093dd..61147229497 100644
--- a/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp
+++ b/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp
@@ -110,6 +110,7 @@ void
lrucache_map<P>::erase(const K & key) {
internal_iterator it = HashTable::find(key);
if (it != HashTable::end()) {
+ next_t h = HashTable::hash(key);
onRemove(key);
LV & v = it->second;
if (v._prev != LinkedValueBase::npos) {
@@ -122,7 +123,7 @@ lrucache_map<P>::erase(const K & key) {
} else {
_tail = v._prev;
}
- HashTable::erase(*this, it);
+ HashTable::erase(*this, h, it);
}
}
@@ -202,7 +203,7 @@ lrucache_map<P>::removeOld() {
{
_tail = last->second._prev;
HashTable::getByInternalIndex(_tail).second._next = LinkedValueBase::npos;
- HashTable::erase(*this, HashTable::find(last->first));
+ HashTable::erase(*this, HashTable::hash(last->first), HashTable::find(last->first));
}
}
}
diff --git a/storage/src/tests/distributor/blockingoperationstartertest.cpp b/storage/src/tests/distributor/blockingoperationstartertest.cpp
index c2fdc25cebf..0160f5c9e51 100644
--- a/storage/src/tests/distributor/blockingoperationstartertest.cpp
+++ b/storage/src/tests/distributor/blockingoperationstartertest.cpp
@@ -60,7 +60,7 @@ BlockingOperationStarterTest::testOperationNotBlockedWhenNoMessagesPending()
{
CPPUNIT_ASSERT(_operationStarter->start(createMockOperation(),
OperationStarter::Priority(0)));
- CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 0\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 0\n"),
_starterImpl->toString());
}
diff --git a/storage/src/tests/distributor/bucketdbupdatertest.cpp b/storage/src/tests/distributor/bucketdbupdatertest.cpp
index b9e33ea8d26..ff442114c4c 100644
--- a/storage/src/tests/distributor/bucketdbupdatertest.cpp
+++ b/storage/src/tests/distributor/bucketdbupdatertest.cpp
@@ -4,6 +4,7 @@
#include <iomanip>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/storage/distributor/bucketdbupdater.h>
+#include <vespa/storage/distributor/distributormetricsset.h>
#include <vespa/storage/distributor/pending_bucket_space_db_transition.h>
#include <vespa/storage/distributor/outdated_nodes_map.h>
#include <vespa/vespalib/io/fileutil.h>
@@ -22,8 +23,7 @@ using namespace storage::lib;
using document::test::makeDocumentBucket;
using document::test::makeBucketSpace;
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class BucketDBUpdaterTest : public CppUnit::TestFixture,
public DistributorTestUtil
@@ -2499,5 +2499,4 @@ void BucketDBUpdaterTest::batch_update_from_distributor_change_does_not_mark_div
"0:5/1/2/3|1:5/7/8/9", true));
}
-} // distributor
-} // storage
+}
diff --git a/storage/src/tests/distributor/externaloperationhandlertest.cpp b/storage/src/tests/distributor/externaloperationhandlertest.cpp
index 683352e6b09..a0b8cd424ac 100644
--- a/storage/src/tests/distributor/externaloperationhandlertest.cpp
+++ b/storage/src/tests/distributor/externaloperationhandlertest.cpp
@@ -2,15 +2,14 @@
#include <tests/distributor/distributortestutil.h>
#include <vespa/storage/distributor/externaloperationhandler.h>
-#include <vespa/storage/distributor/operation_sequencer.h>
-#include <vespa/storageapi/message/persistence.h>
#include <vespa/storage/distributor/distributor.h>
+#include <vespa/storage/distributor/distributormetricsset.h>
+#include <vespa/storageapi/message/persistence.h>
#include <vespa/document/test/make_document_bucket.h>
using document::test::makeDocumentBucket;
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class ExternalOperationHandlerTest : public CppUnit::TestFixture,
public DistributorTestUtil
@@ -471,5 +470,4 @@ void ExternalOperationHandlerTest::sequencing_can_be_explicitly_config_disabled(
// pseudo-locks in the sequencer. I.e. if we get a RemoveLocation with id.user==123456, this
// prevents any handles from being acquired to any GID under location BucketId(32, 123456).
-} // distributor
-} // storage
+}
diff --git a/storage/src/tests/distributor/getoperationtest.cpp b/storage/src/tests/distributor/getoperationtest.cpp
index 8bb8e24c17a..80c093dea87 100644
--- a/storage/src/tests/distributor/getoperationtest.cpp
+++ b/storage/src/tests/distributor/getoperationtest.cpp
@@ -5,16 +5,13 @@
#include <vespa/document/repo/documenttyperepo.h>
#include <vespa/storage/distributor/externaloperationhandler.h>
#include <vespa/storage/distributor/distributor.h>
+#include <vespa/storage/distributor/distributormetricsset.h>
#include <tests/distributor/distributortestutil.h>
#include <vespa/storageapi/message/persistence.h>
-#include <tests/common/dummystoragelink.h>
#include <vespa/document/test/make_document_bucket.h>
-#include <vespa/vdstestlib/cppunit/macros.h>
#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/config/helper/configgetter.hpp>
#include <iomanip>
-#include <iostream>
-#include <memory>
#include <vespa/storage/distributor/operations/external/getoperation.h>
using std::shared_ptr;
@@ -23,8 +20,7 @@ using document::DocumenttypesConfig;
using config::FileSpec;
using document::test::makeDocumentBucket;
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class GetOperationTest : public CppUnit::TestFixture, public DistributorTestUtil {
CPPUNIT_TEST_SUITE(GetOperationTest);
@@ -568,5 +564,4 @@ GetOperationTest::canGetDocumentsWhenAllReplicaNodesRetired()
_sender.getCommands(true));
}
-} // distributor
-} // storage
+}
diff --git a/storage/src/tests/distributor/idealstatemanagertest.cpp b/storage/src/tests/distributor/idealstatemanagertest.cpp
index 0c695f9a3d4..bca15d702f5 100644
--- a/storage/src/tests/distributor/idealstatemanagertest.cpp
+++ b/storage/src/tests/distributor/idealstatemanagertest.cpp
@@ -143,9 +143,9 @@ IdealStateManagerTest::testClearActiveOnNodeDown()
}
CPPUNIT_ASSERT_EQUAL(
- std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"
- "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)) (pri 100)\n"
- "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)) (pri 100)\n"),
+ std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"
+ "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)) (pri 100)\n"
+ "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)) (pri 100)\n"),
_distributor->getActiveIdealStateOperations());
setSystemState(lib::ClusterState("distributor:1 storage:3 .0.s:d"));
@@ -169,19 +169,19 @@ IdealStateManagerTest::testRecheckWhenActive()
tick();
CPPUNIT_ASSERT_EQUAL(
- std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"),
+ std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"),
_distributor->getActiveIdealStateOperations());
tick();
CPPUNIT_ASSERT_EQUAL(
- std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"),
+ std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"),
_distributor->getActiveIdealStateOperations());
tick();
CPPUNIT_ASSERT_EQUAL(
- std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"),
+ std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"),
_distributor->getActiveIdealStateOperations());
}
diff --git a/storage/src/tests/distributor/maintenanceschedulertest.cpp b/storage/src/tests/distributor/maintenanceschedulertest.cpp
index 7e3d92053f8..db0347617f0 100644
--- a/storage/src/tests/distributor/maintenanceschedulertest.cpp
+++ b/storage/src/tests/distributor/maintenanceschedulertest.cpp
@@ -70,7 +70,7 @@ MaintenanceSchedulerTest::testOperationIsScheduled()
{
_priorityDb->setPriority(PrioritizedBucket(makeDocumentBucket(BucketId(16, 1)), Priority::MEDIUM));
_scheduler->tick(MaintenanceScheduler::NORMAL_SCHEDULING_MODE);
- CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 100\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 100\n"),
_operationStarter->toString());
}
@@ -89,9 +89,9 @@ MaintenanceSchedulerTest::testSuppressLowPrioritiesInEmergencyMode()
_priorityDb->setPriority(PrioritizedBucket(makeDocumentBucket(BucketId(16, 2)), Priority::VERY_HIGH));
CPPUNIT_ASSERT_EQUAL(WaitTimeMs(0), _scheduler->tick(MaintenanceScheduler::RECOVERY_SCHEDULING_MODE));
CPPUNIT_ASSERT_EQUAL(WaitTimeMs(1), _scheduler->tick(MaintenanceScheduler::RECOVERY_SCHEDULING_MODE));
- CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri 0\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri 0\n"),
_operationStarter->toString());
- CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri HIGH)\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri HIGH)\n"),
_priorityDb->toString());
}
@@ -102,7 +102,7 @@ MaintenanceSchedulerTest::testPriorityNotClearedIfOperationNotStarted()
_operationStarter->setShouldStartOperations(false);
WaitTimeMs waitMs(_scheduler->tick(MaintenanceScheduler::NORMAL_SCHEDULING_MODE));
CPPUNIT_ASSERT_EQUAL(WaitTimeMs(1), waitMs);
- CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri HIGH)\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri HIGH)\n"),
_priorityDb->toString());
}
diff --git a/storage/src/tests/distributor/messagesenderstub.h b/storage/src/tests/distributor/messagesenderstub.h
index e5bae9c6702..b86863890a1 100644
--- a/storage/src/tests/distributor/messagesenderstub.h
+++ b/storage/src/tests/distributor/messagesenderstub.h
@@ -3,6 +3,7 @@
#include <vespa/storage/distributor/distributormessagesender.h>
#include <cassert>
+#include <vector>
namespace storage {
diff --git a/storage/src/tests/distributor/pendingmessagetrackertest.cpp b/storage/src/tests/distributor/pendingmessagetrackertest.cpp
index cca55a11b38..7adadd226d7 100644
--- a/storage/src/tests/distributor/pendingmessagetrackertest.cpp
+++ b/storage/src/tests/distributor/pendingmessagetrackertest.cpp
@@ -254,7 +254,7 @@ PendingMessageTrackerTest::testSimple()
CPPUNIT_ASSERT_CONTAIN(
std::string(
- "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000004d2))</b>\n"
+ "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000004d2))</b>\n"
"<ul>\n"
"<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> "
"Remove(BucketId(0x40000000000004d2), "
@@ -341,14 +341,14 @@ PendingMessageTrackerTest::testMultipleMessages()
CPPUNIT_ASSERT_CONTAIN(
std::string(
- "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000004d2))</b>\n"
+ "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000004d2))</b>\n"
"<ul>\n"
"<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:0, timestamp 1000)</li>\n"
"<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:2, timestamp 1002)</li>\n"
"<li><i>Node 1</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:1, timestamp 1001)</li>\n"
"<li><i>Node 1</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:3, timestamp 1003)</li>\n"
"</ul>\n"
- "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000011d7))</b>\n"
+ "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000011d7))</b>\n"
"<ul>\n"
"<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000011d7), userdoc:footype:4567:0, timestamp 2000)</li>\n"
"<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000011d7), userdoc:footype:4567:2, timestamp 2002)</li>\n"
diff --git a/storage/src/tests/distributor/putoperationtest.cpp b/storage/src/tests/distributor/putoperationtest.cpp
index 7f54e163006..e621ef8645c 100644
--- a/storage/src/tests/distributor/putoperationtest.cpp
+++ b/storage/src/tests/distributor/putoperationtest.cpp
@@ -282,7 +282,7 @@ PutOperationTest::testNodeRemovedOnReply()
CPPUNIT_ASSERT_EQUAL(std::string(
"PutReply(doc:test:test, BucketId(0x0000000000000000), "
"timestamp 100) ReturnCode(BUCKET_DELETED, "
- "Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000008b13)) was deleted from nodes [0] "
+ "Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000008b13)) was deleted from nodes [0] "
"after message was sent but before it was done. "
"Sent to [1,0])"),
_sender.getLastReply());
diff --git a/storage/src/tests/distributor/simplemaintenancescannertest.cpp b/storage/src/tests/distributor/simplemaintenancescannertest.cpp
index 66a2d3efa6c..394df6024fd 100644
--- a/storage/src/tests/distributor/simplemaintenancescannertest.cpp
+++ b/storage/src/tests/distributor/simplemaintenancescannertest.cpp
@@ -92,7 +92,7 @@ void
SimpleMaintenanceScannerTest::testPrioritizeSingleBucket()
{
addBucketToDb(1);
- std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n");
+ std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n");
auto scanResult = _scanner->scanNext();
CPPUNIT_ASSERT(!scanResult.isDone());
@@ -141,9 +141,9 @@ SimpleMaintenanceScannerTest::testPrioritizeMultipleBuckets()
addBucketToDb(1);
addBucketToDb(2);
addBucketToDb(3);
- std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
- "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri VERY_HIGH)\n"
- "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n");
+ std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
+ "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri VERY_HIGH)\n"
+ "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n");
CPPUNIT_ASSERT(scanEntireDatabase(3));
CPPUNIT_ASSERT_EQUAL(sortLines(expected),
@@ -168,8 +168,8 @@ SimpleMaintenanceScannerTest::testReset()
addBucketToDb(3);
CPPUNIT_ASSERT(scanEntireDatabase(2));
- std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
- "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n");
+ std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
+ "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n");
CPPUNIT_ASSERT_EQUAL(expected, _priorityDb->toString());
addBucketToDb(2);
@@ -179,9 +179,9 @@ SimpleMaintenanceScannerTest::testReset()
_scanner->reset();
CPPUNIT_ASSERT(scanEntireDatabase(3));
- expected = "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
- "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri VERY_HIGH)\n"
- "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n";
+ expected = "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"
+ "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri VERY_HIGH)\n"
+ "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n";
CPPUNIT_ASSERT_EQUAL(sortLines(expected), sortLines(_priorityDb->toString()));
}
diff --git a/storage/src/tests/distributor/throttlingoperationstartertest.cpp b/storage/src/tests/distributor/throttlingoperationstartertest.cpp
index c3aebcafe06..c3290a8c0f6 100644
--- a/storage/src/tests/distributor/throttlingoperationstartertest.cpp
+++ b/storage/src/tests/distributor/throttlingoperationstartertest.cpp
@@ -70,7 +70,7 @@ ThrottlingOperationStarterTest::testOperationStartingIsForwardedToImplementation
{
CPPUNIT_ASSERT(_operationStarter->start(createMockOperation(),
OperationStarter::Priority(0)));
- CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 0\n"),
+ CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 0\n"),
_starterImpl->toString());
}
diff --git a/storage/src/tests/distributor/visitoroperationtest.cpp b/storage/src/tests/distributor/visitoroperationtest.cpp
index 972ccf41bfe..17d1bc288ca 100644
--- a/storage/src/tests/distributor/visitoroperationtest.cpp
+++ b/storage/src/tests/distributor/visitoroperationtest.cpp
@@ -9,6 +9,7 @@
#include <vespa/storageapi/message/state.h>
#include <vespa/storage/distributor/operations/external/visitoroperation.h>
#include <vespa/storage/distributor/operations/external/visitororder.h>
+#include <vespa/storage/distributor/distributormetricsset.h>
#include <tests/distributor/distributortestutil.h>
#include <vespa/storage/distributor/distributor.h>
#include <tests/common/dummystoragelink.h>
@@ -21,8 +22,7 @@ using namespace storage::lib;
using namespace std::string_literals;
using document::test::makeBucketSpace;
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class VisitorOperationTest : public CppUnit::TestFixture,
public DistributorTestUtil {
@@ -1674,5 +1674,4 @@ VisitorOperationTest::statistical_metrics_not_updated_on_wrong_distribution()
CPPUNIT_ASSERT_EQUAL(0.0, defaultVisitorMetrics().latency.getCount());
}
-} // distributor
-} // storage
+}
diff --git a/storage/src/tests/persistence/splitbitdetectortest.cpp b/storage/src/tests/persistence/splitbitdetectortest.cpp
index c20aae373ec..01baa8f4e98 100644
--- a/storage/src/tests/persistence/splitbitdetectortest.cpp
+++ b/storage/src/tests/persistence/splitbitdetectortest.cpp
@@ -8,6 +8,7 @@
#include <vespa/persistence/spi/test.h>
#include <vespa/document/base/testdocman.h>
#include <vespa/document/bucket/bucketidfactory.h>
+#include <vespa/metrics/loadmetric.h>
#include <algorithm>
using storage::spi::test::makeSpiBucket;
diff --git a/storage/src/tests/storageserver/CMakeLists.txt b/storage/src/tests/storageserver/CMakeLists.txt
index 38fb0f6235a..95faf7e433e 100644
--- a/storage/src/tests/storageserver/CMakeLists.txt
+++ b/storage/src/tests/storageserver/CMakeLists.txt
@@ -5,6 +5,7 @@ vespa_add_library(storage_teststorageserver TEST
bucketintegritycheckertest.cpp
changedbucketownershiphandlertest.cpp
communicationmanagertest.cpp
+ configurable_bucket_resolver_test.cpp
documentapiconvertertest.cpp
mergethrottlertest.cpp
priorityconvertertest.cpp
diff --git a/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp b/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp
new file mode 100644
index 00000000000..3f121240065
--- /dev/null
+++ b/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp
@@ -0,0 +1,137 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/storage/storageserver/configurable_bucket_resolver.h>
+#include <vespa/document/base/documentid.h>
+#include <vespa/persistence/spi/fixed_bucket_spaces.h>
+#include <cppunit/extensions/HelperMacros.h>
+
+namespace storage {
+
+using document::DocumentId;
+
+struct ConfigurableBucketResolverTest : CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(ConfigurableBucketResolverTest);
+ CPPUNIT_TEST(bucket_space_from_name_is_defined_for_default_space);
+ CPPUNIT_TEST(bucket_space_from_name_is_defined_for_global_space);
+ CPPUNIT_TEST(bucket_space_from_name_throws_exception_for_unknown_space);
+ CPPUNIT_TEST(name_from_bucket_space_is_defined_for_default_space);
+ CPPUNIT_TEST(name_from_bucket_space_is_defined_for_global_space);
+ CPPUNIT_TEST(name_from_bucket_space_throws_exception_for_unknown_space);
+ CPPUNIT_TEST(known_bucket_space_is_resolved_from_document_id);
+ CPPUNIT_TEST(unknown_bucket_space_in_id_throws_exception);
+ CPPUNIT_TEST(can_create_resolver_from_bucket_space_config);
+ CPPUNIT_TEST_SUITE_END();
+
+ using BucketSpaceMapping = ConfigurableBucketResolver::BucketSpaceMapping;
+
+ BucketSpaceMapping create_simple_mapping() {
+ return {{"foo", spi::FixedBucketSpaces::default_space()},
+ {"bar", spi::FixedBucketSpaces::default_space()},
+ {"baz", spi::FixedBucketSpaces::global_space()}};
+ }
+
+ ConfigurableBucketResolver create_empty_resolver() {
+ return ConfigurableBucketResolver({});
+ }
+
+ ConfigurableBucketResolver create_simple_resolver() {
+ return ConfigurableBucketResolver(create_simple_mapping());
+ }
+
+ void bucket_space_from_name_is_defined_for_default_space();
+ void bucket_space_from_name_is_defined_for_global_space();
+ void bucket_space_from_name_throws_exception_for_unknown_space();
+ void name_from_bucket_space_is_defined_for_default_space();
+ void name_from_bucket_space_is_defined_for_global_space();
+ void name_from_bucket_space_throws_exception_for_unknown_space();
+ void known_bucket_space_is_resolved_from_document_id();
+ void unknown_bucket_space_in_id_throws_exception();
+ void can_create_resolver_from_bucket_space_config();
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(ConfigurableBucketResolverTest);
+
+// TODO reduce overlap with FixedBucketSpacesTest
+void ConfigurableBucketResolverTest::bucket_space_from_name_is_defined_for_default_space() {
+ auto space = create_empty_resolver().bucketSpaceFromName("default");
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(), space);
+}
+
+void ConfigurableBucketResolverTest::bucket_space_from_name_is_defined_for_global_space() {
+ auto space = create_empty_resolver().bucketSpaceFromName("global");
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(), space);
+}
+
+void ConfigurableBucketResolverTest::bucket_space_from_name_throws_exception_for_unknown_space() {
+ try {
+ create_empty_resolver().bucketSpaceFromName("bjarne");
+ CPPUNIT_FAIL("Expected exception on unknown bucket space name");
+ } catch (spi::UnknownBucketSpaceException& e) {
+ }
+}
+
+void ConfigurableBucketResolverTest::name_from_bucket_space_is_defined_for_default_space() {
+ CPPUNIT_ASSERT_EQUAL(vespalib::string("default"),
+ create_empty_resolver().nameFromBucketSpace(spi::FixedBucketSpaces::default_space()));
+}
+
+void ConfigurableBucketResolverTest::name_from_bucket_space_is_defined_for_global_space() {
+ CPPUNIT_ASSERT_EQUAL(vespalib::string("global"),
+ create_empty_resolver().nameFromBucketSpace(spi::FixedBucketSpaces::global_space()));
+}
+
+void ConfigurableBucketResolverTest::name_from_bucket_space_throws_exception_for_unknown_space() {
+ try {
+ create_empty_resolver().nameFromBucketSpace(document::BucketSpace(1234));
+ CPPUNIT_FAIL("Expected exception on unknown bucket space value");
+ } catch (spi::UnknownBucketSpaceException& e) {
+ }
+}
+
+void ConfigurableBucketResolverTest::known_bucket_space_is_resolved_from_document_id() {
+ auto resolver = create_simple_resolver();
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(),
+ resolver.bucketFromId(DocumentId("id::foo::xyz")).getBucketSpace());
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(),
+ resolver.bucketFromId(DocumentId("id::bar::xyz")).getBucketSpace());
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(),
+ resolver.bucketFromId(DocumentId("id::baz::xyz")).getBucketSpace());
+}
+
+void ConfigurableBucketResolverTest::unknown_bucket_space_in_id_throws_exception() {
+ try {
+ create_simple_resolver().bucketFromId(DocumentId("id::bjarne::xyz"));
+ CPPUNIT_FAIL("Expected exception on unknown document type -> bucket space mapping");
+ } catch (spi::UnknownBucketSpaceException& e) {
+ }
+}
+
+using BucketSpacesConfigBuilder = vespa::config::content::core::BucketspacesConfigBuilder;
+
+namespace {
+
+BucketSpacesConfigBuilder::Documenttype make_doc_type(vespalib::stringref name, vespalib::stringref space) {
+ BucketSpacesConfigBuilder::Documenttype doc_type;
+ doc_type.name = name;
+ doc_type.bucketspace = space;
+ return doc_type;
+}
+
+}
+
+void ConfigurableBucketResolverTest::can_create_resolver_from_bucket_space_config() {
+ BucketSpacesConfigBuilder builder;
+ builder.documenttype.emplace_back(make_doc_type("foo", "default"));
+ builder.documenttype.emplace_back(make_doc_type("bar", "global"));
+ builder.documenttype.emplace_back(make_doc_type("baz", "global"));
+ auto resolver = ConfigurableBucketResolver::from_config(builder);
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(),
+ resolver->bucketFromId(DocumentId("id::foo::xyz")).getBucketSpace());
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(),
+ resolver->bucketFromId(DocumentId("id::bar::xyz")).getBucketSpace());
+ CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(),
+ resolver->bucketFromId(DocumentId("id::baz::xyz")).getBucketSpace());
+}
+
+}
+
diff --git a/storage/src/tests/storageserver/documentapiconvertertest.cpp b/storage/src/tests/storageserver/documentapiconvertertest.cpp
index 386be60d88c..b878d5f6719 100644
--- a/storage/src/tests/storageserver/documentapiconvertertest.cpp
+++ b/storage/src/tests/storageserver/documentapiconvertertest.cpp
@@ -60,13 +60,13 @@ struct MockBucketResolver : public BucketResolver {
struct DocumentApiConverterTest : public CppUnit::TestFixture
{
- MockBucketResolver _bucketResolver;
+ std::shared_ptr<MockBucketResolver> _bucketResolver;
std::unique_ptr<DocumentApiConverter> _converter;
const DocumentTypeRepo::SP _repo;
const DataType& _html_type;
DocumentApiConverterTest()
- : _bucketResolver(),
+ : _bucketResolver(std::make_shared<MockBucketResolver>()),
_repo(std::make_shared<DocumentTypeRepo>(readDocumenttypesConfig(
TEST_PATH("config-doctypes.cfg")))),
_html_type(*_repo->getDocumentType("text/html"))
@@ -120,6 +120,7 @@ struct DocumentApiConverterTest : public CppUnit::TestFixture
void testStatBucket();
void testGetBucketList();
void testRemoveLocation();
+ void can_replace_bucket_resolver_after_construction();
CPPUNIT_TEST_SUITE(DocumentApiConverterTest);
CPPUNIT_TEST(testPut);
@@ -138,6 +139,7 @@ struct DocumentApiConverterTest : public CppUnit::TestFixture
CPPUNIT_TEST(testStatBucket);
CPPUNIT_TEST(testGetBucketList);
CPPUNIT_TEST(testRemoveLocation);
+ CPPUNIT_TEST(can_replace_bucket_resolver_after_construction);
CPPUNIT_TEST_SUITE_END();
};
@@ -463,4 +465,29 @@ DocumentApiConverterTest::testRemoveLocation()
CPPUNIT_ASSERT_EQUAL(defaultBucket, cmd->getBucket());
}
+namespace {
+
+struct ReplacementMockBucketResolver : public MockBucketResolver {
+ Bucket bucketFromId(const DocumentId& id) const override {
+ if (id.getDocType() == "testdoctype1") {
+ return defaultBucket;
+ }
+ return Bucket(BucketSpace(0), BucketId(0));
+ }
+};
+
+}
+
+void DocumentApiConverterTest::can_replace_bucket_resolver_after_construction() {
+ documentapi::GetDocumentMessage get_msg(DocumentId("id::testdoctype1::baz"), "foo bar");
+ auto cmd = toStorageAPI<api::GetCommand>(get_msg);
+
+ CPPUNIT_ASSERT_EQUAL(BucketSpace(0), cmd->getBucket().getBucketSpace());
+
+ _converter->setBucketResolver(std::make_shared<ReplacementMockBucketResolver>());
+
+ cmd = toStorageAPI<api::GetCommand>(get_msg);
+ CPPUNIT_ASSERT_EQUAL(defaultBucketSpace, cmd->getBucket().getBucketSpace());
+}
+
}
diff --git a/storage/src/vespa/storage/common/bucketmessages.cpp b/storage/src/vespa/storage/common/bucketmessages.cpp
index 3157bad49e5..e92e2d4c3bf 100644
--- a/storage/src/vespa/storage/common/bucketmessages.cpp
+++ b/storage/src/vespa/storage/common/bucketmessages.cpp
@@ -2,6 +2,7 @@
#include "bucketmessages.h"
#include <vespa/vespalib/stllike/asciistream.h>
+#include <ostream>
using document::BucketSpace;
diff --git a/storage/src/vespa/storage/common/messagesender.h b/storage/src/vespa/storage/common/messagesender.h
index 8c45995c42f..659fccad412 100644
--- a/storage/src/vespa/storage/common/messagesender.h
+++ b/storage/src/vespa/storage/common/messagesender.h
@@ -18,13 +18,14 @@
#include <memory>
-namespace storage {
-namespace api {
+namespace storage::api {
class StorageCommand;
class StorageReply;
class StorageMessage;
}
+namespace storage {
+
struct MessageSender {
virtual ~MessageSender() {}
diff --git a/storage/src/vespa/storage/common/storagecomponent.cpp b/storage/src/vespa/storage/common/storagecomponent.cpp
index bf387240dc5..1d6b563f6eb 100644
--- a/storage/src/vespa/storage/common/storagecomponent.cpp
+++ b/storage/src/vespa/storage/common/storagecomponent.cpp
@@ -28,14 +28,14 @@ StorageComponent::setNodeInfo(vespalib::stringref clusterName,
void
StorageComponent::setDocumentTypeRepo(DocumentTypeRepoSP repo)
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
_docTypeRepo = repo;
}
void
StorageComponent::setLoadTypes(LoadTypeSetSP loadTypes)
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
_loadTypes = loadTypes;
}
@@ -57,14 +57,21 @@ StorageComponent::setBucketIdFactory(const document::BucketIdFactory& factory)
void
StorageComponent::setDistribution(DistributionSP distribution)
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
_distribution = distribution;
}
void
+StorageComponent::enableMultipleBucketSpaces(bool value)
+{
+ std::lock_guard guard(_lock);
+ _enableMultipleBucketSpaces = value;
+}
+
+void
StorageComponent::setNodeStateUpdater(NodeStateUpdater& updater)
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
if (_nodeStateUpdater != 0) {
throw vespalib::IllegalStateException(
"Node state updater is already set", VESPA_STRLOC);
@@ -76,10 +83,16 @@ StorageComponent::StorageComponent(StorageComponentRegister& compReg,
vespalib::stringref name)
: Component(compReg, name),
_clusterName(),
- _nodeType(0),
+ _nodeType(nullptr),
_index(0),
+ _docTypeRepo(),
+ _loadTypes(),
_priorityMapper(new PriorityMapper),
- _nodeStateUpdater(0)
+ _bucketIdFactory(),
+ _distribution(),
+ _nodeStateUpdater(nullptr),
+ _lock(),
+ _enableMultipleBucketSpaces(false)
{
compReg.registerStorageComponent(*this);
}
@@ -87,7 +100,7 @@ StorageComponent::StorageComponent(StorageComponentRegister& compReg,
NodeStateUpdater&
StorageComponent::getStateUpdater() const
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
if (_nodeStateUpdater == 0) {
throw vespalib::IllegalStateException(
"Component need node state updater at this time, but it has "
@@ -114,22 +127,29 @@ StorageComponent::getPriority(const documentapi::LoadType& lt) const
StorageComponent::DocumentTypeRepoSP
StorageComponent::getTypeRepo() const
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
return _docTypeRepo;
}
StorageComponent::LoadTypeSetSP
StorageComponent::getLoadTypes() const
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
return _loadTypes;
}
StorageComponent::DistributionSP
StorageComponent::getDistribution() const
{
- std::lock_guard<std::mutex> guard(_lock);
+ std::lock_guard guard(_lock);
return _distribution;
}
+bool
+StorageComponent::enableMultipleBucketSpaces() const
+{
+ std::lock_guard guard(_lock);
+ return _enableMultipleBucketSpaces;
+}
+
} // storage
diff --git a/storage/src/vespa/storage/common/storagecomponent.h b/storage/src/vespa/storage/common/storagecomponent.h
index d469540b55f..e136d991ac5 100644
--- a/storage/src/vespa/storage/common/storagecomponent.h
+++ b/storage/src/vespa/storage/common/storagecomponent.h
@@ -37,10 +37,9 @@
#include <vespa/vdslib/state/node.h>
#include <mutex>
-namespace vespa { namespace config { namespace content { namespace core {
-namespace internal {
+namespace vespa::config::content::core::internal {
class InternalStorPrioritymappingType;
-} } } } }
+}
namespace document {
class DocumentTypeRepo;
}
@@ -59,11 +58,11 @@ class StorageComponentRegister;
class StorageComponent : public framework::Component {
public:
- typedef std::unique_ptr<StorageComponent> UP;
- typedef vespa::config::content::core::internal::InternalStorPrioritymappingType PriorityConfig;
- typedef std::shared_ptr<document::DocumentTypeRepo> DocumentTypeRepoSP;
- typedef std::shared_ptr<documentapi::LoadTypeSet> LoadTypeSetSP;
- typedef std::shared_ptr<lib::Distribution> DistributionSP;
+ using UP = std::unique_ptr<StorageComponent>;
+ using PriorityConfig = vespa::config::content::core::internal::InternalStorPrioritymappingType;
+ using DocumentTypeRepoSP = std::shared_ptr<document::DocumentTypeRepo>;
+ using LoadTypeSetSP = std::shared_ptr<documentapi::LoadTypeSet>;
+ using DistributionSP = std::shared_ptr<lib::Distribution>;
/**
* Node type is supposed to be set immediately, and never be updated.
@@ -84,6 +83,7 @@ public:
void setPriorityConfig(const PriorityConfig&);
void setBucketIdFactory(const document::BucketIdFactory&);
void setDistribution(DistributionSP);
+ void enableMultipleBucketSpaces(bool value);
StorageComponent(StorageComponentRegister&, vespalib::stringref name);
virtual ~StorageComponent();
@@ -102,6 +102,7 @@ public:
uint8_t getPriority(const documentapi::LoadType&) const;
DistributionSP getDistribution() const;
NodeStateUpdater& getStateUpdater() const;
+ bool enableMultipleBucketSpaces() const;
private:
vespalib::string _clusterName;
@@ -114,6 +115,7 @@ private:
DistributionSP _distribution;
NodeStateUpdater* _nodeStateUpdater;
mutable std::mutex _lock;
+ bool _enableMultipleBucketSpaces;
};
struct StorageComponentRegister : public virtual framework::ComponentRegister
diff --git a/storage/src/vespa/storage/distributor/bucketdbupdater.cpp b/storage/src/vespa/storage/distributor/bucketdbupdater.cpp
index 46fa0f72d76..cc1181e0d58 100644
--- a/storage/src/vespa/storage/distributor/bucketdbupdater.cpp
+++ b/storage/src/vespa/storage/distributor/bucketdbupdater.cpp
@@ -2,9 +2,9 @@
#include "bucketdbupdater.h"
#include "distributor.h"
-#include "distributor_bucket_space_repo.h"
#include "distributor_bucket_space.h"
#include "simpleclusterinformation.h"
+#include "distributormetricsset.h"
#include <vespa/storage/common/bucketoperationlogger.h>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/storageapi/message/removelocation.h>
diff --git a/storage/src/vespa/storage/distributor/bucketdbupdater.h b/storage/src/vespa/storage/distributor/bucketdbupdater.h
index 29e8d3f6221..19e2e259778 100644
--- a/storage/src/vespa/storage/distributor/bucketdbupdater.h
+++ b/storage/src/vespa/storage/distributor/bucketdbupdater.h
@@ -13,9 +13,8 @@
#include <vespa/vdslib/state/clusterstate.h>
#include <vespa/storage/common/storagelink.h>
#include <vespa/storageframework/generic/clock/timer.h>
+#include <vespa/storageframework/generic/status/statusreporter.h>
#include <vespa/storageapi/messageapi/messagehandler.h>
-#include <set>
-#include <deque>
#include <list>
namespace storage::distributor {
diff --git a/storage/src/vespa/storage/distributor/bucketgctimecalculator.h b/storage/src/vespa/storage/distributor/bucketgctimecalculator.h
index e2b232a6cf5..4ff85e568c8 100644
--- a/storage/src/vespa/storage/distributor/bucketgctimecalculator.h
+++ b/storage/src/vespa/storage/distributor/bucketgctimecalculator.h
@@ -4,8 +4,7 @@
#include <chrono>
#include <vespa/document/bucket/bucketid.h>
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
/**
* Semantics are basically as follows:
@@ -51,6 +50,4 @@ private:
std::chrono::seconds _checkInterval;
};
-} // distributor
-} // storage
-
+}
diff --git a/storage/src/vespa/storage/distributor/bucketownership.h b/storage/src/vespa/storage/distributor/bucketownership.h
index c7a7773686f..bfe63c9799d 100644
--- a/storage/src/vespa/storage/distributor/bucketownership.h
+++ b/storage/src/vespa/storage/distributor/bucketownership.h
@@ -2,9 +2,9 @@
#pragma once
#include <vespa/vdslib/state/clusterstate.h>
+#include <cassert>
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class BucketOwnership
{
@@ -14,8 +14,7 @@ class BucketOwnership
BucketOwnership(const lib::ClusterState& checkedState)
: _checkedState(&checkedState),
_owned(false)
- {
- }
+ { }
BucketOwnership() : _checkedState(nullptr), _owned(true) {}
@@ -44,6 +43,4 @@ public:
}
};
-} // distributor
-} // storage
-
+}
diff --git a/storage/src/vespa/storage/distributor/distributor.cpp b/storage/src/vespa/storage/distributor/distributor.cpp
index 1edcbe75dd6..988d39e571d 100644
--- a/storage/src/vespa/storage/distributor/distributor.cpp
+++ b/storage/src/vespa/storage/distributor/distributor.cpp
@@ -5,21 +5,17 @@
#include "throttlingoperationstarter.h"
#include "idealstatemetricsset.h"
#include "ownership_transfer_safe_time_point_calculator.h"
-#include "distributor_bucket_space_repo.h"
#include "distributor_bucket_space.h"
-#include <vespa/storage/bucketdb/mapbucketdatabase.h>
-#include <vespa/storage/distributor/maintenance/simplemaintenancescanner.h>
+#include "distributormetricsset.h"
#include <vespa/storage/distributor/maintenance/simplebucketprioritydatabase.h>
#include <vespa/storage/common/nodestateupdater.h>
#include <vespa/storage/common/hostreporter/hostinfo.h>
#include <vespa/storageframework/generic/status/xmlstatusreporter.h>
-
#include <vespa/log/log.h>
LOG_SETUP(".distributor-main");
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class Distributor::Status {
const DelegatedStatusRequest& _request;
@@ -68,34 +64,25 @@ Distributor::Distributor(DistributorComponentRegister& compReg,
_compReg(compReg),
_component(compReg, "distributor"),
_bucketSpaceRepo(std::make_unique<DistributorBucketSpaceRepo>()),
- _metrics(new DistributorMetricSet(
- _component.getLoadTypes()->getMetricLoadTypes())),
+ _metrics(new DistributorMetricSet(_component.getLoadTypes()->getMetricLoadTypes())),
_operationOwner(*this, _component.getClock()),
_maintenanceOperationOwner(*this, _component.getClock()),
_pendingMessageTracker(compReg),
_bucketDBUpdater(*this, *_bucketSpaceRepo, *this, compReg),
_distributorStatusDelegate(compReg, *this, *this),
_bucketDBStatusDelegate(compReg, *this, _bucketDBUpdater),
- _idealStateManager(*this, *_bucketSpaceRepo, compReg,
- manageActiveBucketCopies),
- _externalOperationHandler(*this, *_bucketSpaceRepo,
- _idealStateManager, compReg),
+ _idealStateManager(*this, *_bucketSpaceRepo, compReg, manageActiveBucketCopies),
+ _externalOperationHandler(*this, *_bucketSpaceRepo, _idealStateManager, compReg),
_threadPool(threadPool),
_initializingIsUp(true),
_doneInitializeHandler(doneInitHandler),
_doneInitializing(false),
_messageSender(messageSender),
_bucketPriorityDb(new SimpleBucketPriorityDatabase()),
- _scanner(new SimpleMaintenanceScanner(
- *_bucketPriorityDb, _idealStateManager,
- *_bucketSpaceRepo)),
- _throttlingStarter(new ThrottlingOperationStarter(
- _maintenanceOperationOwner)),
- _blockingStarter(new BlockingOperationStarter(_pendingMessageTracker,
- *_throttlingStarter)),
- _scheduler(new MaintenanceScheduler(_idealStateManager,
- *_bucketPriorityDb,
- *_blockingStarter)),
+ _scanner(new SimpleMaintenanceScanner(*_bucketPriorityDb, _idealStateManager, *_bucketSpaceRepo)),
+ _throttlingStarter(new ThrottlingOperationStarter(_maintenanceOperationOwner)),
+ _blockingStarter(new BlockingOperationStarter(_pendingMessageTracker, *_throttlingStarter)),
+ _scheduler(new MaintenanceScheduler(_idealStateManager, *_bucketPriorityDb, *_blockingStarter)),
_schedulingMode(MaintenanceScheduler::NORMAL_SCHEDULING_MODE),
_recoveryTimeStarted(_component.getClock()),
_tickResult(framework::ThreadWaitInfo::NO_MORE_CRITICAL_WORK_KNOWN),
@@ -105,8 +92,7 @@ Distributor::Distributor(DistributorComponentRegister& compReg,
_metricLock(),
_maintenanceStats(),
_bucketDbStats(),
- _hostInfoReporter(_pendingMessageTracker.getLatencyStatisticsProvider(),
- *this),
+ _hostInfoReporter(_pendingMessageTracker.getLatencyStatisticsProvider(), *this),
_ownershipSafeTimeCalc(
std::make_unique<OwnershipTransferSafeTimePointCalculator>(
std::chrono::seconds(0))) // Set by config later
@@ -162,10 +148,8 @@ void
Distributor::sendCommand(const std::shared_ptr<api::StorageCommand>& cmd)
{
if (cmd->getType() == api::MessageType::MERGEBUCKET) {
- api::MergeBucketCommand& merge(
- static_cast<api::MergeBucketCommand&>(*cmd));
- _idealStateManager.getMetrics().nodesPerMerge.addValue(
- merge.getNodes().size());
+ api::MergeBucketCommand& merge(static_cast<api::MergeBucketCommand&>(*cmd));
+ _idealStateManager.getMetrics().nodesPerMerge.addValue(merge.getNodes().size());
}
sendUp(cmd);
}
@@ -179,10 +163,8 @@ Distributor::sendReply(const std::shared_ptr<api::StorageReply>& reply)
void
Distributor::setNodeStateUp()
{
- NodeStateUpdater::Lock::SP lock(
- _component.getStateUpdater().grabStateChangeLock());
- lib::NodeState ns(
- *_component.getStateUpdater().getReportedNodeState());
+ NodeStateUpdater::Lock::SP lock(_component.getStateUpdater().grabStateChangeLock());
+ lib::NodeState ns(*_component.getStateUpdater().getReportedNodeState());
ns.setState(lib::State::UP);
_component.getStateUpdater().setReportedNodeState(ns);
}
@@ -832,5 +814,4 @@ Distributor::handleStatusRequest(const DelegatedStatusRequest& request) const
return true;
}
-} // distributor
-} // storage
+}
diff --git a/storage/src/vespa/storage/distributor/distributorinterface.h b/storage/src/vespa/storage/distributor/distributorinterface.h
index bf27dc432b6..3445397c17d 100644
--- a/storage/src/vespa/storage/distributor/distributorinterface.h
+++ b/storage/src/vespa/storage/distributor/distributorinterface.h
@@ -1,32 +1,28 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
-#include <vespa/storage/common/distributorcomponent.h>
-#include <vespa/storage/common/messagesender.h>
-#include <vespa/storage/distributor/pendingmessagetracker.h>
-#include <vespa/storageapi/message/state.h>
+#include "bucketgctimecalculator.h"
+#include "distributormessagesender.h"
+#include "bucketownership.h"
#include <vespa/storage/bucketdb/bucketdatabase.h>
-#include <vespa/storage/distributor/bucketgctimecalculator.h>
-#include <vespa/storage/distributor/distributormetricsset.h>
-#include <vespa/storage/config/distributorconfiguration.h>
-#include <vespa/storage/distributor/distributormessagesender.h>
-#include <vespa/storage/distributor/bucketownership.h>
+#include <vespa/document/bucket/bucket.h>
+namespace storage::api { class MergeBucketReply; }
namespace storage {
+ class DistributorConfiguration;
+ class DistributorMetricSet;
+}
+namespace storage::distributor {
-namespace distributor {
+class PendingMessageTracker;
class DistributorInterface : public DistributorMessageSender
{
public:
virtual PendingMessageTracker& getPendingMessageTracker() = 0;
-
virtual DistributorMetricSet& getMetrics() = 0;
-
virtual void enableClusterState(const lib::ClusterState& state) = 0;
-
virtual BucketOwnership checkOwnershipInPendingState(const document::Bucket &bucket) const = 0;
-
virtual void notifyDistributionChangeEnabled() = 0;
/**
@@ -55,19 +51,11 @@ public:
* Returns true if the node is currently initializing.
*/
virtual bool initializing() const = 0;
-
virtual void handleCompletedMerge(const std::shared_ptr<api::MergeBucketReply>&) = 0;
-
virtual const char* getStorageNodeUpStates() const = 0;
-
virtual const DistributorConfiguration& getConfig() const = 0;
-
virtual ChainedMessageSender& getMessageSender() = 0;
-
virtual const BucketGcTimeCalculator::BucketIdHasher& getBucketIdHasher() const = 0;
};
}
-
-}
-
diff --git a/storage/src/vespa/storage/distributor/distributormessagesender.h b/storage/src/vespa/storage/distributor/distributormessagesender.h
index 0fccaad87e3..078762dd05c 100644
--- a/storage/src/vespa/storage/distributor/distributormessagesender.h
+++ b/storage/src/vespa/storage/distributor/distributormessagesender.h
@@ -2,11 +2,9 @@
#pragma once
#include <vespa/storage/common/messagesender.h>
-#include <vespa/vdslib/distribution/distribution.h>
-namespace storage {
-
-namespace distributor {
+namespace storage::lib { class NodeType; }
+namespace storage::distributor {
class PendingMessageTracker;
@@ -16,21 +14,12 @@ public:
Sends the storage command to the given node,
returns message id.
*/
- virtual uint64_t sendToNode(const lib::NodeType& nodeType,
- uint16_t node,
- const std::shared_ptr<api::StorageCommand>& cmd,
- bool useDocumentAPI = false);
+ virtual uint64_t sendToNode(const lib::NodeType& nodeType, uint16_t node,
+ const std::shared_ptr<api::StorageCommand>& cmd, bool useDocumentAPI = false);
virtual int getDistributorIndex() const = 0;
-
virtual const std::string& getClusterName() const = 0;
-
virtual const PendingMessageTracker& getPendingMessageTracker() const = 0;
};
-} // distributor
-
-} // storage
-
-
-
+}
diff --git a/storage/src/vespa/storage/distributor/idealstatemanager.cpp b/storage/src/vespa/storage/distributor/idealstatemanager.cpp
index 4ceeb387341..031e9946178 100644
--- a/storage/src/vespa/storage/distributor/idealstatemanager.cpp
+++ b/storage/src/vespa/storage/distributor/idealstatemanager.cpp
@@ -95,7 +95,7 @@ IdealStateManager::getEntryForPrimaryBucket(StateChecker::Context& c) const
{
for (uint32_t j = 0; j < c.entries.size(); ++j) {
BucketDatabase::Entry& e = c.entries[j];
- if (e.getBucketId() == c.getBucketId()) {
+ if (e.getBucketId() == c.getBucketId() && ! e->getNodes().empty()) {
return &e;
}
}
diff --git a/storage/src/vespa/storage/distributor/idealstatemanager.h b/storage/src/vespa/storage/distributor/idealstatemanager.h
index b9607b35d28..028c9cbb0b6 100644
--- a/storage/src/vespa/storage/distributor/idealstatemanager.h
+++ b/storage/src/vespa/storage/distributor/idealstatemanager.h
@@ -1,18 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
-#include <deque>
-#include <map>
-#include <set>
-#include <vespa/storage/distributor/distributorcomponent.h>
-#include <vespa/storage/distributor/statechecker.h>
+#include "distributorcomponent.h"
+#include "statechecker.h"
#include <vespa/storage/distributor/maintenance/maintenanceprioritygenerator.h>
#include <vespa/storage/distributor/maintenance/maintenanceoperationgenerator.h>
+#include <vespa/storageframework/generic/status/htmlstatusreporter.h>
#include <vespa/vdslib/state/clusterstate.h>
-#include <vector>
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class IdealStateMetricSet;
class IdealStateOperation;
@@ -116,8 +112,7 @@ private:
DistributorComponent _distributorComponent;
DistributorBucketSpaceRepo &_bucketSpaceRepo;
- std::vector<IdealStateOperation::SP> generateOperationsForBucket(
- StateChecker::Context& c) const;
+ std::vector<IdealStateOperation::SP> generateOperationsForBucket(StateChecker::Context& c) const;
bool iAmUp() const;
@@ -125,9 +120,9 @@ private:
// Stats tracker to use for all generateAll() calls to avoid having
// to create a new hash map for each single bucket processed.
NodeMaintenanceStatsTracker _statsTracker;
- const IdealStateManager& _ism;
- document::BucketSpace _bucketSpace;
- std::ostream& _out;
+ const IdealStateManager & _ism;
+ document::BucketSpace _bucketSpace;
+ std::ostream & _out;
public:
StatusBucketVisitor(const IdealStateManager& ism, document::BucketSpace bucketSpace, std::ostream& out)
: _statsTracker(), _ism(ism), _bucketSpace(bucketSpace), _out(out) {}
@@ -139,11 +134,8 @@ private:
};
friend class StatusBucketVisitor;
- void getBucketStatus(document::BucketSpace bucketSpace,
- const BucketDatabase::Entry& entry,
- NodeMaintenanceStatsTracker& statsTracker,
- std::ostream& out) const;
+ void getBucketStatus(document::BucketSpace bucketSpace, const BucketDatabase::Entry& entry,
+ NodeMaintenanceStatsTracker& statsTracker, std::ostream& out) const;
};
-} // distributor
-} // storage
+}
diff --git a/storage/src/vespa/storage/distributor/messagetracker.cpp b/storage/src/vespa/storage/distributor/messagetracker.cpp
index b844987e978..6568cec9a80 100644
--- a/storage/src/vespa/storage/distributor/messagetracker.cpp
+++ b/storage/src/vespa/storage/distributor/messagetracker.cpp
@@ -1,19 +1,19 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "messagetracker.h"
+#include <vespa/storageapi/messageapi/bucketcommand.h>
+#include <vespa/storageapi/messageapi/bucketreply.h>
#include <vespa/log/log.h>
LOG_SETUP(".messagetracker");
-namespace storage {
-
-namespace distributor {
+namespace storage::distributor {
MessageTracker::MessageTracker(const std::string& clusterName)
: _clusterName(clusterName)
{}
-MessageTracker::~MessageTracker() {}
+MessageTracker::~MessageTracker() = default;
void
MessageTracker::flushQueue(MessageSender& sender)
@@ -48,7 +48,4 @@ MessageTracker::finished()
return _sentMessages.empty();
}
-
-}
-
}
diff --git a/storage/src/vespa/storage/distributor/messagetracker.h b/storage/src/vespa/storage/distributor/messagetracker.h
index 63c0be1ca93..017979c16c0 100644
--- a/storage/src/vespa/storage/distributor/messagetracker.h
+++ b/storage/src/vespa/storage/distributor/messagetracker.h
@@ -1,10 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
-#include "distributormetricsset.h"
#include <vespa/storage/common/messagesender.h>
-#include <vespa/storageapi/messageapi/bucketcommand.h>
-#include <vespa/storageapi/messageapi/bucketreply.h>
+#include <vector>
+#include <map>
+
+namespace storage::api {
+ class BucketCommand;
+ class BucketReply;
+}
namespace storage::distributor {
diff --git a/storage/src/vespa/storage/distributor/operations/external/putoperation.h b/storage/src/vespa/storage/distributor/operations/external/putoperation.h
index 8beffe8b2c3..c27f2ee2266 100644
--- a/storage/src/vespa/storage/distributor/operations/external/putoperation.h
+++ b/storage/src/vespa/storage/distributor/operations/external/putoperation.h
@@ -5,22 +5,21 @@
#include <vespa/storage/distributor/operations/sequenced_operation.h>
#include <vespa/storageapi/messageapi/returncode.h>
#include <vespa/storage/distributor/persistencemessagetracker.h>
-#include <vespa/storage/distributor/operationtargetresolver.h>
namespace document {
class Document;
}
-namespace storage {
-namespace lib {
+namespace storage::lib {
class Distribution;
}
-namespace api {
+namespace storage::api {
class CreateBucketReply;
class PutCommand;
}
-namespace distributor {
+namespace storage::distributor {
class DistributorBucketSpace;
+class OperationTargetList;
class PutOperation : public SequencedOperation
{
@@ -78,5 +77,4 @@ private:
DistributorBucketSpace &_bucketSpace;
};
-} // distributor
-} // storage
+}
diff --git a/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h b/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h
index af448c2dd55..d40924e23f0 100644
--- a/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h
+++ b/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h
@@ -8,12 +8,11 @@
#pragma once
#include <vespa/storage/distributor/operations/operation.h>
+#include <map>
-namespace storage {
+namespace storage::api { class StatBucketCommand; }
-namespace api { class StatBucketCommand; }
-
-namespace distributor {
+namespace storage::distributor {
class DistributorComponent;
class DistributorBucketSpace;
@@ -21,9 +20,8 @@ class DistributorBucketSpace;
class StatBucketOperation : public Operation
{
public:
- StatBucketOperation(DistributorComponent& manager,
- DistributorBucketSpace &bucketSpace,
- const std::shared_ptr<api::StatBucketCommand> & cmd);
+ StatBucketOperation(DistributorComponent& manager, DistributorBucketSpace &bucketSpace,
+ const std::shared_ptr<api::StatBucketCommand> & cmd);
~StatBucketOperation();
const char* getName() const override { return "statBucket"; }
@@ -37,10 +35,8 @@ private:
std::shared_ptr<api::StatBucketCommand> _command;
- std::map<uint64_t, uint16_t> _sent;
+ std::map<uint64_t, uint16_t> _sent;
std::map<uint16_t, std::string> _results;
};
-} // distributor
-} // storage
-
+}
diff --git a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp
index 79ffee7430c..db120880267 100644
--- a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp
+++ b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp
@@ -4,13 +4,12 @@
#include "getoperation.h"
#include "putoperation.h"
#include "updateoperation.h"
-#include <vespa/document/fieldvalue/document.h>
-#include <vespa/document/datatype/documenttype.h>
-#include <vespa/document/select/parser.h>
+#include <vespa/storage/distributor/distributor_bucket_space.h>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/storageapi/message/batch.h>
+#include <vespa/document/datatype/documenttype.h>
+#include <vespa/document/select/parser.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
-#include <vespa/storage/distributor/distributor_bucket_space.h>
#include <vespa/log/log.h>
LOG_SETUP(".distributor.callback.twophaseupdate");
@@ -18,8 +17,7 @@ LOG_SETUP(".distributor.callback.twophaseupdate");
using namespace std::literals::string_literals;
using document::BucketSpace;
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
TwoPhaseUpdateOperation::TwoPhaseUpdateOperation(
DistributorComponent& manager,
@@ -570,5 +568,4 @@ TwoPhaseUpdateOperation::onClose(DistributorMessageSender& sender) {
}
}
-} // distributor
-} // storage
+}
diff --git a/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h b/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h
index f35a9dcb3ec..b4f84d76649 100644
--- a/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h
+++ b/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h
@@ -11,11 +11,10 @@
namespace document { class Document; }
-namespace storage {
+namespace storage { class VisitorMetricSet; }
+namespace storage::lib { class ClusterState; }
-class VisitorMetricSet;
-
-namespace distributor {
+namespace storage::distributor {
class DistributorComponent;
class DistributorBucketSpace;
@@ -181,5 +180,3 @@ private:
};
}
-
-}
diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp
index 2337129e375..52c8344b820 100644
--- a/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp
+++ b/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp
@@ -3,9 +3,8 @@
#include <vespa/storage/distributor/idealstatemanager.h>
#include <vespa/storage/distributor/pendingmessagetracker.h>
#include <vespa/storage/distributor/idealstatemetricsset.h>
-#include <vespa/storage/distributor/pendingmessagetracker.h>
#include <vespa/storage/distributor/distributor_bucket_space_repo.h>
-#include <vespa/storageapi/messageapi/maintenancecommand.h>
+#include <vespa/documentapi/loadtypes/loadtypeset.h>
#include <vespa/log/log.h>
LOG_SETUP(".distributor.operation");
@@ -26,17 +25,15 @@ const uint32_t IdealStateOperation::MAINTENANCE_MESSAGE_TYPES[] =
};
IdealStateOperation::IdealStateOperation(const BucketAndNodes& bucketAndNodes)
- : _manager(nullptr),
- _bucketSpace(nullptr),
- _bucketAndNodes(bucketAndNodes),
- _ok(true),
- _priority(255)
+ : _manager(nullptr),
+ _bucketSpace(nullptr),
+ _bucketAndNodes(bucketAndNodes),
+ _ok(true),
+ _priority(255)
{
}
-IdealStateOperation::~IdealStateOperation()
-{
-}
+IdealStateOperation::~IdealStateOperation() = default;
BucketAndNodes::BucketAndNodes(const document::Bucket &bucket, uint16_t node)
: _bucket(bucket)
@@ -108,8 +105,7 @@ IdealStateOperation::setCommandMeta(api::MaintenanceCommand& cmd) const
{
cmd.setPriority(_priority);
cmd.setReason(_detailedReason);
- cmd.setLoadType(
- (*_manager->getLoadTypes())["maintenance"]);
+ cmd.setLoadType((*_manager->getLoadTypes())["maintenance"]);
}
std::string
diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp
index 271ac35968e..32ea695bd94 100644
--- a/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp
+++ b/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp
@@ -2,7 +2,7 @@
#include "mergeoperation.h"
#include <vespa/storage/distributor/idealstatemanager.h>
#include <vespa/storage/distributor/distributor_bucket_space.h>
-#include <array>
+#include <vespa/storage/distributor/pendingmessagetracker.h>
#include <vespa/log/bufferedlogger.h>
LOG_SETUP(".distributor.operation.idealstate.merge");
diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp
index 1acb2dcc64b..6a87688c295 100644
--- a/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp
+++ b/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp
@@ -3,6 +3,7 @@
#include "setbucketstateoperation.h"
#include <vespa/storage/distributor/idealstatemanager.h>
#include <vespa/storage/distributor/distributor_bucket_space.h>
+#include <vespa/storageapi/message/bucket.h>
#include <vespa/log/log.h>
LOG_SETUP(".distributor.operation.idealstate.setactive");
diff --git a/storage/src/vespa/storage/distributor/operationtargetresolver.h b/storage/src/vespa/storage/distributor/operationtargetresolver.h
index 23e0fbbcba4..20666ea254c 100644
--- a/storage/src/vespa/storage/distributor/operationtargetresolver.h
+++ b/storage/src/vespa/storage/distributor/operationtargetresolver.h
@@ -10,8 +10,7 @@
#include <vespa/vdslib/state/node.h>
#include <vespa/vespalib/util/printable.h>
-namespace storage {
-namespace distributor {
+namespace storage::distributor {
class OperationTarget : public vespalib::AsciiPrintable
{
@@ -68,5 +67,4 @@ public:
const document::BucketId& id) = 0;
};
-} // distributor
-} // storage
+}
diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp
index d2cb8fa4380..2f9430c49bb 100644
--- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp
+++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp
@@ -10,11 +10,20 @@ LOG_SETUP(".storage.component.register");
namespace storage {
StorageComponentRegisterImpl::StorageComponentRegisterImpl()
- : _nodeType(0),
+ : _componentLock(),
+ _components(),
+ _clusterName(),
+ _nodeType(nullptr),
_index(0xffff),
+ _docTypeRepo(),
_loadTypes(new documentapi::LoadTypeSet),
- _nodeStateUpdater(0)
-{ }
+ _priorityConfig(),
+ _bucketIdFactory(),
+ _distribution(),
+ _nodeStateUpdater(nullptr),
+ _bucketSpacesConfig()
+{
+}
StorageComponentRegisterImpl::~StorageComponentRegisterImpl() { }
@@ -33,6 +42,7 @@ StorageComponentRegisterImpl::registerStorageComponent(StorageComponent& smc)
smc.setPriorityConfig(_priorityConfig);
smc.setBucketIdFactory(_bucketIdFactory);
smc.setDistribution(_distribution);
+ smc.enableMultipleBucketSpaces(_bucketSpacesConfig.enableMultipleBucketSpaces);
}
void
@@ -115,4 +125,14 @@ StorageComponentRegisterImpl::setDistribution(lib::Distribution::SP distribution
}
}
+void
+StorageComponentRegisterImpl::setBucketSpacesConfig(const BucketspacesConfig& config)
+{
+ vespalib::LockGuard lock(_componentLock);
+ _bucketSpacesConfig = config;
+ for (size_t i = 0; i < _components.size(); ++i) {
+ _components[i]->enableMultipleBucketSpaces(config.enableMultipleBucketSpaces);
+ }
+}
+
} // storage
diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h
index 49387e2c2b5..afd9f11a88b 100644
--- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h
+++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h
@@ -11,6 +11,7 @@
#include <vespa/document/repo/documenttyperepo.h>
#include <vespa/documentapi/loadtypes/loadtypeset.h>
#include <vespa/storage/common/storagecomponent.h>
+#include <vespa/storage/config/config-bucketspaces.h>
#include <vespa/storage/config/config-stor-prioritymapping.h>
#include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h>
#include <vespa/vdslib/distribution/distribution.h>
@@ -21,9 +22,9 @@ class StorageComponentRegisterImpl
: public virtual StorageComponentRegister,
public virtual framework::defaultimplementation::ComponentRegisterImpl
{
- typedef framework::defaultimplementation::ComponentRegisterImpl CompRegImpl;
- typedef StorageComponent::PriorityConfig PriorityConfig;
- //CompRegImpl _compReg;
+ using PriorityConfig = StorageComponent::PriorityConfig;
+ using BucketspacesConfig = vespa::config::content::core::internal::InternalBucketspacesType;
+
vespalib::Lock _componentLock;
std::vector<StorageComponent*> _components;
vespalib::string _clusterName;
@@ -35,6 +36,7 @@ class StorageComponentRegisterImpl
document::BucketIdFactory _bucketIdFactory;
lib::Distribution::SP _distribution;
NodeStateUpdater* _nodeStateUpdater;
+ BucketspacesConfig _bucketSpacesConfig;
public:
typedef std::unique_ptr<StorageComponentRegisterImpl> UP;
@@ -64,6 +66,7 @@ public:
virtual void setPriorityConfig(const PriorityConfig&);
virtual void setBucketIdFactory(const document::BucketIdFactory&);
virtual void setDistribution(lib::Distribution::SP);
+ virtual void setBucketSpacesConfig(const BucketspacesConfig&);
};
diff --git a/storage/src/vespa/storage/persistence/splitbitdetector.h b/storage/src/vespa/storage/persistence/splitbitdetector.h
index b3fc5bea566..6f1af6c5970 100644
--- a/storage/src/vespa/storage/persistence/splitbitdetector.h
+++ b/storage/src/vespa/storage/persistence/splitbitdetector.h
@@ -18,6 +18,7 @@
#pragma once
#include <vespa/persistence/spi/persistenceprovider.h>
+#include <vespa/vespalib/util/printable.h>
namespace storage {
diff --git a/storage/src/vespa/storage/storageserver/CMakeLists.txt b/storage/src/vespa/storage/storageserver/CMakeLists.txt
index c0238922a91..4fb3a5a0b99 100644
--- a/storage/src/vespa/storage/storageserver/CMakeLists.txt
+++ b/storage/src/vespa/storage/storageserver/CMakeLists.txt
@@ -6,6 +6,7 @@ vespa_add_library(storage_storageserver
changedbucketownershiphandler.cpp
communicationmanager.cpp
communicationmanagermetrics.cpp
+ configurable_bucket_resolver.cpp
distributornode.cpp
distributornodecontext.cpp
documentapiconverter.cpp
diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.cpp b/storage/src/vespa/storage/storageserver/communicationmanager.cpp
index c19dc7cfd27..a2c923b93db 100644
--- a/storage/src/vespa/storage/storageserver/communicationmanager.cpp
+++ b/storage/src/vespa/storage/storageserver/communicationmanager.cpp
@@ -288,8 +288,7 @@ CommunicationManager::CommunicationManager(StorageComponentRegister& compReg, co
_count(0),
_configUri(configUri),
_closed(false),
- _bucketResolver(std::make_unique<PlaceHolderBucketResolver>()),
- _docApiConverter(configUri, *_bucketResolver)
+ _docApiConverter(configUri, std::make_shared<PlaceHolderBucketResolver>())
{
_component.registerMetricUpdateHook(*this, framework::SecondTime(5));
_component.registerMetric(_metrics);
diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.h b/storage/src/vespa/storage/storageserver/communicationmanager.h
index f4f4aa5a236..b4508fbc9f9 100644
--- a/storage/src/vespa/storage/storageserver/communicationmanager.h
+++ b/storage/src/vespa/storage/storageserver/communicationmanager.h
@@ -170,7 +170,6 @@ private:
config::ConfigUri _configUri;
std::atomic<bool> _closed;
- std::unique_ptr<BucketResolver> _bucketResolver;
DocumentApiConverter _docApiConverter;
framework::Thread::UP _thread;
diff --git a/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp
new file mode 100644
index 00000000000..86c802a65cf
--- /dev/null
+++ b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp
@@ -0,0 +1,36 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/document/base/documentid.h>
+#include <vespa/persistence/spi/fixed_bucket_spaces.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include "configurable_bucket_resolver.h"
+
+namespace storage {
+
+document::Bucket ConfigurableBucketResolver::bucketFromId(const document::DocumentId& id) const {
+ auto iter = _type_to_space.find(id.getDocType());
+ if (iter != _type_to_space.end()) {
+ return document::Bucket(iter->second, document::BucketId(0));
+ }
+ throw spi::UnknownBucketSpaceException("Unknown bucket space mapping for document type '"
+ + id.getDocType() + "' in id: " + id.toString(), VESPA_STRLOC);
+}
+
+document::BucketSpace ConfigurableBucketResolver::bucketSpaceFromName(const vespalib::string& name) const {
+ return spi::FixedBucketSpaces::from_string(name);
+}
+
+vespalib::string ConfigurableBucketResolver::nameFromBucketSpace(const document::BucketSpace& space) const {
+ return spi::FixedBucketSpaces::to_string(space);
+}
+
+std::shared_ptr<ConfigurableBucketResolver> ConfigurableBucketResolver::from_config(
+ const vespa::config::content::core::BucketspacesConfig& config) {
+ ConfigurableBucketResolver::BucketSpaceMapping type_to_space;
+ for (auto& mapping : config.documenttype) {
+ type_to_space.emplace(mapping.name, spi::FixedBucketSpaces::from_string(mapping.bucketspace));
+ }
+ return std::make_shared<ConfigurableBucketResolver>(std::move(type_to_space));
+}
+
+}
diff --git a/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h
new file mode 100644
index 00000000000..acebd9777fb
--- /dev/null
+++ b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h
@@ -0,0 +1,36 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include <vespa/storage/config/config-bucketspaces.h>
+#include <vespa/storage/common/bucket_resolver.h>
+#include <vespa/vespalib/stllike/hash_fun.h>
+#include <memory>
+#include <unordered_map>
+
+namespace storage {
+
+/**
+ * Immutable implementation of BucketResolver which maintains an explicit
+ * mapping from document type to bucket space.
+ *
+ * If an unknown document type or bucket space is given as an argument,
+ * an spi::UnknownBucketSpaceException is thrown.
+ */
+class ConfigurableBucketResolver : public BucketResolver {
+public:
+ using BucketSpaceMapping = std::unordered_map<vespalib::string, document::BucketSpace, vespalib::hash<vespalib::string>>;
+ const BucketSpaceMapping _type_to_space;
+public:
+ explicit ConfigurableBucketResolver(BucketSpaceMapping type_to_space)
+ : _type_to_space(std::move(type_to_space))
+ {}
+
+ document::Bucket bucketFromId(const document::DocumentId&) const override;
+ document::BucketSpace bucketSpaceFromName(const vespalib::string& name) const override;
+ vespalib::string nameFromBucketSpace(const document::BucketSpace& space) const override;
+
+ static std::shared_ptr<ConfigurableBucketResolver> from_config(
+ const vespa::config::content::core::BucketspacesConfig& config);
+};
+
+} \ No newline at end of file
diff --git a/storage/src/vespa/storage/storageserver/documentapiconverter.cpp b/storage/src/vespa/storage/storageserver/documentapiconverter.cpp
index c2761b3d832..09ca9924891 100644
--- a/storage/src/vespa/storage/storageserver/documentapiconverter.cpp
+++ b/storage/src/vespa/storage/storageserver/documentapiconverter.cpp
@@ -24,9 +24,9 @@ using document::BucketSpace;
namespace storage {
DocumentApiConverter::DocumentApiConverter(const config::ConfigUri &configUri,
- const BucketResolver &bucketResolver)
+ std::shared_ptr<const BucketResolver> bucketResolver)
: _priConverter(std::make_unique<PriorityConverter>(configUri)),
- _bucketResolver(bucketResolver)
+ _bucketResolver(std::move(bucketResolver))
{}
DocumentApiConverter::~DocumentApiConverter() {}
@@ -42,7 +42,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_PUTDOCUMENT:
{
documentapi::PutDocumentMessage& from(static_cast<documentapi::PutDocumentMessage&>(fromMsg));
- document::Bucket bucket = _bucketResolver.bucketFromId(from.getDocument().getId());
+ document::Bucket bucket = bucketResolver()->bucketFromId(from.getDocument().getId());
auto to = std::make_unique<api::PutCommand>(bucket, from.stealDocument(), from.getTimestamp());
to->setCondition(from.getCondition());
toMsg = std::move(to);
@@ -51,7 +51,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_UPDATEDOCUMENT:
{
documentapi::UpdateDocumentMessage& from(static_cast<documentapi::UpdateDocumentMessage&>(fromMsg));
- document::Bucket bucket = _bucketResolver.bucketFromId(from.getDocumentUpdate().getId());
+ document::Bucket bucket = bucketResolver()->bucketFromId(from.getDocumentUpdate().getId());
auto to = std::make_unique<api::UpdateCommand>(bucket, from.stealDocumentUpdate(), from.getNewTimestamp());
to->setOldTimestamp(from.getOldTimestamp());
to->setCondition(from.getCondition());
@@ -61,7 +61,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_REMOVEDOCUMENT:
{
documentapi::RemoveDocumentMessage& from(static_cast<documentapi::RemoveDocumentMessage&>(fromMsg));
- auto to = std::make_unique<api::RemoveCommand>(_bucketResolver.bucketFromId(from.getDocumentId()), from.getDocumentId(), 0);
+ auto to = std::make_unique<api::RemoveCommand>(bucketResolver()->bucketFromId(from.getDocumentId()), from.getDocumentId(), 0);
to->setCondition(from.getCondition());
toMsg = std::move(to);
break;
@@ -69,14 +69,14 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_GETDOCUMENT:
{
documentapi::GetDocumentMessage& from(static_cast<documentapi::GetDocumentMessage&>(fromMsg));
- auto to = std::make_unique<api::GetCommand>(_bucketResolver.bucketFromId(from.getDocumentId()), from.getDocumentId(), from.getFieldSet());
+ auto to = std::make_unique<api::GetCommand>(bucketResolver()->bucketFromId(from.getDocumentId()), from.getDocumentId(), from.getFieldSet());
toMsg.reset(to.release());
break;
}
case DocumentProtocol::MESSAGE_CREATEVISITOR:
{
documentapi::CreateVisitorMessage& from(static_cast<documentapi::CreateVisitorMessage&>(fromMsg));
- auto to = std::make_unique<api::CreateVisitorCommand>(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()),
+ auto to = std::make_unique<api::CreateVisitorCommand>(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()),
from.getLibraryName(), from.getInstanceId(),
from.getDocumentSelection());
@@ -118,14 +118,14 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_STATBUCKET:
{
documentapi::StatBucketMessage& from(static_cast<documentapi::StatBucketMessage&>(fromMsg));
- document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), from.getBucketId());
+ document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), from.getBucketId());
toMsg = std::make_unique<api::StatBucketCommand>(bucket, from.getDocumentSelection());
break;
}
case DocumentProtocol::MESSAGE_GETBUCKETLIST:
{
documentapi::GetBucketListMessage& from(static_cast<documentapi::GetBucketListMessage&>(fromMsg));
- document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), from.getBucketId());
+ document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), from.getBucketId());
toMsg = std::make_unique<api::GetBucketListCommand>(bucket);
break;
}
@@ -145,7 +145,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg,
case DocumentProtocol::MESSAGE_REMOVELOCATION:
{
documentapi::RemoveLocationMessage& from(static_cast<documentapi::RemoveLocationMessage&>(fromMsg));
- document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), document::BucketId(0));
+ document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), document::BucketId(0));
api::RemoveLocationCommand::UP to(new api::RemoveLocationCommand(from.getDocumentSelection(), bucket));
toMsg.reset(to.release());
break;
@@ -298,7 +298,7 @@ DocumentApiConverter::toDocumentAPI(api::StorageCommand& fromMsg, const document
documentapi::CreateVisitorMessage::UP to(
new documentapi::CreateVisitorMessage(from.getLibraryName(), from.getInstanceId(),
from.getControlDestination(), from.getDataDestination()));
- to->setBucketSpace(_bucketResolver.nameFromBucketSpace(from.getBucketSpace()));
+ to->setBucketSpace(bucketResolver()->nameFromBucketSpace(from.getBucketSpace()));
to->setDocumentSelection(from.getDocumentSelection());
to->setMaximumPendingReplyCount(from.getMaximumPendingReplyCount());
to->setParameters(from.getParameters());
@@ -325,7 +325,7 @@ DocumentApiConverter::toDocumentAPI(api::StorageCommand& fromMsg, const document
{
api::StatBucketCommand& from(static_cast<api::StatBucketCommand&>(fromMsg));
auto statMsg = std::make_unique<documentapi::StatBucketMessage>(from.getBucket().getBucketId(), from.getDocumentSelection());
- statMsg->setBucketSpace(_bucketResolver.nameFromBucketSpace(from.getBucket().getBucketSpace()));
+ statMsg->setBucketSpace(bucketResolver()->nameFromBucketSpace(from.getBucket().getBucketSpace()));
toMsg = std::move(statMsg);
break;
}
@@ -404,4 +404,14 @@ DocumentApiConverter::transferReplyState(api::StorageReply& fromMsg, mbus::Reply
}
}
+std::shared_ptr<const BucketResolver> DocumentApiConverter::bucketResolver() const {
+ std::lock_guard lock(_mutex);
+ return _bucketResolver;
+}
+
+void DocumentApiConverter::setBucketResolver(std::shared_ptr<const BucketResolver> resolver) {
+ std::lock_guard lock(_mutex);
+ _bucketResolver = std::move(resolver);
+}
+
} // storage
diff --git a/storage/src/vespa/storage/storageserver/documentapiconverter.h b/storage/src/vespa/storage/storageserver/documentapiconverter.h
index 5310bcd0127..546bc86a007 100644
--- a/storage/src/vespa/storage/storageserver/documentapiconverter.h
+++ b/storage/src/vespa/storage/storageserver/documentapiconverter.h
@@ -4,6 +4,7 @@
#include <vespa/documentapi/messagebus/messages/documentmessage.h>
#include <vespa/documentapi/messagebus/messages/documentreply.h>
#include <vespa/document/repo/documenttyperepo.h>
+#include <mutex>
namespace config { class ConfigUri; }
namespace storage {
@@ -23,7 +24,7 @@ class DocumentApiConverter
{
public:
DocumentApiConverter(const config::ConfigUri &configUri,
- const BucketResolver &bucketResolver);
+ std::shared_ptr<const BucketResolver> bucketResolver);
~DocumentApiConverter();
std::unique_ptr<api::StorageCommand> toStorageAPI(documentapi::DocumentMessage& msg, const document::DocumentTypeRepo::SP &repo);
@@ -31,9 +32,14 @@ public:
void transferReplyState(storage::api::StorageReply& from, mbus::Reply& to);
std::unique_ptr<mbus::Message> toDocumentAPI(api::StorageCommand& cmd, const document::DocumentTypeRepo::SP &repo);
const PriorityConverter& getPriorityConverter() const { return *_priConverter; }
+
+ // BucketResolver getter and setter are both thread safe.
+ std::shared_ptr<const BucketResolver> bucketResolver() const;
+ void setBucketResolver(std::shared_ptr<const BucketResolver> resolver);
private:
+ mutable std::mutex _mutex;
std::unique_ptr<PriorityConverter> _priConverter;
- const BucketResolver &_bucketResolver;
+ std::shared_ptr<const BucketResolver> _bucketResolver;
};
} // namespace storage
diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp
index ba1556bd3b9..d60f46e5a07 100644
--- a/storage/src/vespa/storage/storageserver/storagenode.cpp
+++ b/storage/src/vespa/storage/storageserver/storagenode.cpp
@@ -76,12 +76,38 @@ StorageNode::StorageNode(
std::unique_ptr<HostInfo> hostInfo,
RunMode mode)
: _singleThreadedDebugMode(mode == SINGLE_THREADED_TEST_MODE),
+ _configFetcher(),
_hostInfo(std::move(hostInfo)),
_context(context),
_generationFetcher(generationFetcher),
+ _rootFolder(),
_attemptedStopped(false),
+ _pidFile(),
+ _statusWebServer(),
+ _metrics(),
+ _metricManager(),
+ _deadLockDetector(),
+ _statusMetrics(),
+ _stateReporter(),
+ _stateManager(),
+ _chain(),
+ _configLock(),
+ _initial_config_mutex(),
+ _serverConfig(),
+ _clusterConfig(),
+ _distributionConfig(),
+ _priorityConfig(),
+ _doctypesConfig(),
+ _bucketSpacesConfig(),
+ _newServerConfig(),
+ _newClusterConfig(),
+ _newDistributionConfig(),
+ _newPriorityConfig(),
+ _newDoctypesConfig(),
+ _newBucketSpacesConfig(),
+ _component(),
_configUri(configUri),
- _communicationManager(0)
+ _communicationManager(nullptr)
{
}
@@ -93,6 +119,7 @@ StorageNode::subscribeToConfigs()
_configFetcher->subscribe<UpgradingConfig>(_configUri.getConfigId(), this);
_configFetcher->subscribe<StorServerConfig>(_configUri.getConfigId(), this);
_configFetcher->subscribe<StorPrioritymappingConfig>(_configUri.getConfigId(), this);
+ _configFetcher->subscribe<BucketspacesConfig>(_configUri.getConfigId(), this);
_configFetcher->start();
@@ -101,6 +128,7 @@ StorageNode::subscribeToConfigs()
_clusterConfig = std::move(_newClusterConfig);
_distributionConfig = std::move(_newDistributionConfig);
_priorityConfig = std::move(_newPriorityConfig);
+ _bucketSpacesConfig = std::move(_newBucketSpacesConfig);
}
void
@@ -127,6 +155,7 @@ StorageNode::initialize()
_context.getComponentRegister().setBucketIdFactory(document::BucketIdFactory());
_context.getComponentRegister().setDistribution(make_shared<lib::Distribution>(*_distributionConfig));
_context.getComponentRegister().setPriorityConfig(*_priorityConfig);
+ _context.getComponentRegister().setBucketSpacesConfig(*_bucketSpacesConfig);
_metrics.reset(new StorageMetricSet);
_component.reset(new StorageComponent(_context.getComponentRegister(), "storagenode"));
@@ -315,6 +344,11 @@ StorageNode::handleLiveConfigUpdate(const InitialGuard & initGuard)
_priorityConfig = std::move(_newPriorityConfig);
_context.getComponentRegister().setPriorityConfig(*_priorityConfig);
}
+ if (_newBucketSpacesConfig) {
+ _bucketSpacesConfig = std::move(_newBucketSpacesConfig);
+ _context.getComponentRegister().setBucketSpacesConfig(*_bucketSpacesConfig);
+ // TODO: Add new bucket space resolver to document api converter
+ }
}
void
@@ -430,7 +464,7 @@ void StorageNode::configure(std::unique_ptr<StorServerConfig> config)
// updates
{
vespalib::LockGuard configLockGuard(_configLock);
- _newServerConfig.reset(config.release());
+ _newServerConfig = std::move(config);
}
if (_serverConfig) {
InitialGuard concurrent_config_guard(_initial_config_mutex);
@@ -447,7 +481,7 @@ StorageNode::configure(std::unique_ptr<UpgradingConfig> config)
// updates
{
vespalib::LockGuard configLockGuard(_configLock);
- _newClusterConfig.reset(config.release());
+ _newClusterConfig = std::move(config);
}
if (_clusterConfig) {
InitialGuard concurrent_config_guard(_initial_config_mutex);
@@ -464,7 +498,7 @@ StorageNode::configure(std::unique_ptr<StorDistributionConfig> config)
// updates
{
vespalib::LockGuard configLockGuard(_configLock);
- _newDistributionConfig.reset(config.release());
+ _newDistributionConfig = std::move(config);
}
if (_distributionConfig) {
InitialGuard concurrent_config_guard(_initial_config_mutex);
@@ -477,7 +511,7 @@ StorageNode::configure(std::unique_ptr<StorPrioritymappingConfig> config)
{
{
vespalib::LockGuard configLockGuard(_configLock);
- _newPriorityConfig.reset(config.release());
+ _newPriorityConfig = std::move(config);
}
if (_priorityConfig) {
InitialGuard concurrent_config_guard(_initial_config_mutex);
@@ -485,15 +519,16 @@ StorageNode::configure(std::unique_ptr<StorPrioritymappingConfig> config)
}
}
-void StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> config,
- bool hasChanged, int64_t generation)
+void
+StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> config,
+ bool hasChanged, int64_t generation)
{
(void) generation;
if (!hasChanged)
return;
{
vespalib::LockGuard configLockGuard(_configLock);
- _newDoctypesConfig.reset(config.release());
+ _newDoctypesConfig = std::move(config);
}
if (_doctypesConfig) {
InitialGuard concurrent_config_guard(_initial_config_mutex);
@@ -501,6 +536,19 @@ void StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> confi
}
}
+void
+StorageNode::configure(std::unique_ptr<BucketspacesConfig> config)
+{
+ {
+ vespalib::LockGuard configLockGuard(_configLock);
+ _newBucketSpacesConfig = std::move(config);
+ }
+ if (_bucketSpacesConfig) {
+ InitialGuard concurrent_config_guard(_initial_config_mutex);
+ handleLiveConfigUpdate(concurrent_config_guard);
+ }
+}
+
bool
StorageNode::attemptedStopped() const
{
diff --git a/storage/src/vespa/storage/storageserver/storagenode.h b/storage/src/vespa/storage/storageserver/storagenode.h
index e9d3004be68..a07d1c0c534 100644
--- a/storage/src/vespa/storage/storageserver/storagenode.h
+++ b/storage/src/vespa/storage/storageserver/storagenode.h
@@ -12,20 +12,19 @@
#pragma once
-#include <vespa/storage/storageutil/resumeguard.h>
-#include <vespa/storage/common/doneinitializehandler.h>
-#include <vespa/storageframework/generic/metric/metricupdatehook.h>
-#include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h>
-
-#include <vespa/config/subscription/configuri.h>
-#include <vespa/config/helper/ifetchercallback.h>
+#include <vespa/config-stor-distribution.h>
+#include <vespa/config-upgrading.h>
#include <vespa/config/helper/configfetcher.h>
-
+#include <vespa/config/helper/ifetchercallback.h>
+#include <vespa/config/subscription/configuri.h>
+#include <vespa/document/config/config-documenttypes.h>
+#include <vespa/storage/common/doneinitializehandler.h>
+#include <vespa/storage/config/config-bucketspaces.h>
#include <vespa/storage/config/config-stor-prioritymapping.h>
#include <vespa/storage/config/config-stor-server.h>
-#include <vespa/document/config/config-documenttypes.h>
-#include <vespa/config-upgrading.h>
-#include <vespa/config-stor-distribution.h>
+#include <vespa/storage/storageutil/resumeguard.h>
+#include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h>
+#include <vespa/storageframework/generic/metric/metricupdatehook.h>
#include <mutex>
namespace document { class DocumentTypeRepo; }
@@ -54,6 +53,7 @@ class StorageNode : private config::IFetcherCallback<vespa::config::content::cor
private config::IFetcherCallback<vespa::config::content::StorDistributionConfig>,
private config::IFetcherCallback<vespa::config::content::UpgradingConfig>,
private config::IFetcherCallback<vespa::config::content::core::StorPrioritymappingConfig>,
+ private config::IFetcherCallback<vespa::config::content::core::BucketspacesConfig>,
private framework::MetricUpdateHook,
private DoneInitializeHandler,
private framework::defaultimplementation::ShutdownListener
@@ -101,6 +101,7 @@ protected:
using UpgradingConfig = vespa::config::content::UpgradingConfig;
using StorDistributionConfig = vespa::config::content::StorDistributionConfig;
using StorPrioritymappingConfig = vespa::config::content::core::StorPrioritymappingConfig;
+ using BucketspacesConfig = vespa::config::content::core::BucketspacesConfig;
private:
bool _singleThreadedDebugMode;
// Subscriptions to config
@@ -137,6 +138,7 @@ private:
void configure(std::unique_ptr<StorPrioritymappingConfig>) override;
virtual void configure(std::unique_ptr<document::DocumenttypesConfig> config,
bool hasChanged, int64_t generation);
+ void configure(std::unique_ptr<BucketspacesConfig>) override;
void updateUpgradeFlag(const UpgradingConfig&);
protected:
@@ -151,12 +153,14 @@ protected:
std::unique_ptr<StorDistributionConfig> _distributionConfig;
std::unique_ptr<StorPrioritymappingConfig> _priorityConfig;
std::unique_ptr<document::DocumenttypesConfig> _doctypesConfig;
+ std::unique_ptr<BucketspacesConfig> _bucketSpacesConfig;
// New configs gotten that has yet to have been handled
std::unique_ptr<StorServerConfig> _newServerConfig;
std::unique_ptr<UpgradingConfig> _newClusterConfig;
std::unique_ptr<StorDistributionConfig> _newDistributionConfig;
std::unique_ptr<StorPrioritymappingConfig> _newPriorityConfig;
std::unique_ptr<document::DocumenttypesConfig> _newDoctypesConfig;
+ std::unique_ptr<BucketspacesConfig> _newBucketSpacesConfig;
std::unique_ptr<StorageComponent> _component;
config::ConfigUri _configUri;
CommunicationManager* _communicationManager;
diff --git a/storageapi/src/vespa/storageapi/message/state.h b/storageapi/src/vespa/storageapi/message/state.h
index e8062c71d22..746d92fce6b 100644
--- a/storageapi/src/vespa/storageapi/message/state.h
+++ b/storageapi/src/vespa/storageapi/message/state.h
@@ -6,8 +6,7 @@
#include <vespa/storageapi/messageapi/storagereply.h>
#include <vespa/vdslib/state/clusterstate.h>
-namespace storage {
-namespace api {
+namespace storage::api {
/**
* @class GetNodeStateCommand
@@ -90,5 +89,4 @@ public:
DECLARE_STORAGEREPLY(SetSystemStateReply, onSetSystemStateReply)
};
-} // api
-} // storage
+}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
index 7874dcb24ab..16a541f939c 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java
@@ -5,6 +5,7 @@ import com.google.common.annotations.Beta;
import com.yahoo.vespa.http.client.Result;
import com.yahoo.vespa.http.client.config.Endpoint;
import com.yahoo.vespa.http.client.core.Document;
+import com.yahoo.vespa.http.client.core.Exceptions;
import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory;
import com.yahoo.vespa.http.client.core.EndpointResult;
import com.yahoo.vespa.http.client.core.ServerResponseException;
@@ -318,29 +319,28 @@ class IOThread implements Runnable, AutoCloseable {
successfullHandshakes.getAndIncrement();
} catch (ServerResponseException ser) {
executeProblemsCounter.incrementAndGet();
- log.log(Level.INFO, "Handshake did not work out " + endpoint, ser.getMessage());
+ log.log(Level.INFO, "Handshake did not work out " + endpoint, Exceptions.toMessageString(ser));
drainFirstDocumentsInQueueIfOld();
return ThreadState.CONNECTED;
} catch (Throwable throwable) { // This cover IOException as well
executeProblemsCounter.incrementAndGet();
- log.log(Level.INFO, "Problem with Handshake " + endpoint, throwable.getMessage());
+ log.log(Level.INFO, "Problem with Handshake " + endpoint, Exceptions.toMessageString(throwable));
drainFirstDocumentsInQueueIfOld();
client.close();
return ThreadState.DISCONNECTED;
}
return ThreadState.SESSION_SYNCED;
case SESSION_SYNCED:
- final int maxWaitTimeMilliSecs = 100;
try {
- ProcessResponse processResponse = pullAndProcessData(maxWaitTimeMilliSecs);
+ ProcessResponse processResponse = pullAndProcessData(100);
gatewayThrottler.handleCall(processResponse.transitiveErrorCount);
}
catch (ServerResponseException ser) {
- log.info("Problems while handing data over to gateway " + endpoint + " " + ser.getMessage());
+ log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(ser));
return ThreadState.CONNECTED;
}
catch (Throwable e) { // Covers IOException as well
- log.info("Problems while handing data over to gateway " + endpoint + " " + e.getMessage());
+ log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(e));
client.close();
return ThreadState.DISCONNECTED;
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
index 0a9fe72552c..5907694f55a 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java
@@ -177,10 +177,14 @@ public class OperationProcessor {
docSendInfoByOperationId.remove(endpointResult.getOperationId());
String documentId = documentSendInfo.getDocument().getDocumentId();
- inflightDocumentIds.remove(documentId);
-
+ /**
+ * If we got a pending operation against this document
+ * dont't remove it from inflightDocuments and send blocked document operation
+ */
List<Document> blockedDocuments = blockedDocumentsByDocumentId.get(documentId);
- if (! blockedDocuments.isEmpty()) {
+ if (blockedDocuments.isEmpty()) {
+ inflightDocumentIds.remove(documentId);
+ } else {
sendToClusters(blockedDocuments.remove(0));
}
return result;
diff --git a/vespajlib/src/main/java/com/yahoo/net/HostName.java b/vespajlib/src/main/java/com/yahoo/net/HostName.java
index 37f7fe80246..157239e456f 100644
--- a/vespajlib/src/main/java/com/yahoo/net/HostName.java
+++ b/vespajlib/src/main/java/com/yahoo/net/HostName.java
@@ -27,7 +27,7 @@ public class HostName {
private static final Logger logger = Logger.getLogger(HostName.class.getName());
- private static String cachedHostName = null;
+ private static String preferredHostName = null;
/**
* Return a public and fully qualified hostname for localhost that resolves to an IP address on
@@ -38,14 +38,14 @@ public class HostName {
* @throws RuntimeException if accessing the network or the 'hostname' command fails
*/
public static synchronized String getLocalhost() {
- if (cachedHostName == null) {
+ if (preferredHostName == null) {
try {
- cachedHostName = getPreferredHostName();
+ preferredHostName = getPreferredHostName();
} catch (Exception e) {
throw new RuntimeException("Failed to find a preferred hostname", e);
}
}
- return cachedHostName;
+ return preferredHostName;
}
private static String getPreferredHostName() throws Exception {
@@ -178,4 +178,7 @@ public class HostName {
}
}
+ public static void setHostNameForTestingOnly(String hostName) {
+ preferredHostName = hostName;
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 00e106dd035..01bf082d32f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -7,13 +7,13 @@ import java.util.Arrays;
/**
* The sizes of a set of dimensions.
- *
+ *
* @author bratseth
*/
@Beta
public final class DimensionSizes {
- private final int[] sizes;
+ private final long[] sizes;
private DimensionSizes(Builder builder) {
this.sizes = builder.sizes;
@@ -25,15 +25,15 @@ public final class DimensionSizes {
*
* @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
*/
- public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+ public long size(int dimensionIndex) { return sizes[dimensionIndex]; }
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
/** Returns the product of the sizes of this */
- public int totalSize() {
- int productSize = 1;
- for (int dimensionSize : sizes )
+ public long totalSize() {
+ long productSize = 1;
+ for (long dimensionSize : sizes )
productSize *= dimensionSize;
return productSize;
}
@@ -48,19 +48,19 @@ public final class DimensionSizes {
@Override
public int hashCode() { return Arrays.hashCode(sizes); }
- /**
+ /**
* Builder of a set of dimension sizes.
* Dimensions whose size is not set before building will get size 0.
*/
public final static class Builder {
- private int[] sizes;
+ private long[] sizes;
public Builder(int dimensions) {
- this.sizes = new int[dimensions];
+ this.sizes = new long[dimensions];
}
- public Builder set(int dimensionIndex, int size) {
+ public Builder set(int dimensionIndex, long size) {
sizes[dimensionIndex] = size;
return this;
}
@@ -70,7 +70,7 @@ public final class DimensionSizes {
*
* @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
*/
- public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+ public long size(int dimensionIndex) { return sizes[dimensionIndex]; }
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index c207dabca3a..7130c053e9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor {
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;
-
+
/** The sizes of the dimensions of this in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
-
+
private final double[] values;
-
+
private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
this.type = type;
this.dimensionSizes = dimensionSizes;
@@ -38,13 +38,13 @@ public class IndexedTensor implements Tensor {
}
@Override
- public int size() {
+ public long size() {
return values.length;
}
/**
- * Returns an iterator over the cells of this.
- * Cells are returned in order of increasing indexes in each dimension, increasing
+ * Returns an iterator over the cells of this.
+ * Cells are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -55,10 +55,10 @@ public class IndexedTensor implements Tensor {
/** Returns an iterator over all the cells in this tensor which matches the given partial address */
// TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently
public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) {
- int[] startAddress = new int[type().dimensions().size()];
+ long[] startAddress = new long[type().dimensions().size()];
List<Integer> iterateDimensions = new ArrayList<>();
for (int i = 0; i < type().dimensions().size(); i++) {
- int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name());
+ long partialAddressLabel = partialAddress.numericLabel(type.dimensions().get(i).name());
if (partialAddressLabel >= 0) // iterate at this label
startAddress[i] = partialAddressLabel;
else // iterate over this dimension
@@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor {
/**
* Returns an iterator over the values of this.
- * Values are returned in order of increasing indexes in each dimension, increasing
+ * Values are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor {
* Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions
* given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as
* other iterator.
- *
+ *
* @param dimensions the names of the dimensions of the superspace
* @param sizes the size of each dimension in the space we are returning values for, containing
* one value per dimension of this tensor (in order). Each size may be the same or smaller
@@ -96,14 +96,14 @@ public class IndexedTensor implements Tensor {
return subspaceIterator(dimensions, dimensionSizes);
}
- /**
+ /**
* Returns the value at the given indexes
- *
+ *
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
- public double get(int ... indexes) {
- return values[toValueIndex(indexes, dimensionSizes)];
+ public double get(long ... indexes) {
+ return values[(int)toValueIndex(indexes, dimensionSizes)];
}
/** Returns the value at this address, or NaN if there is no value at this address */
@@ -111,20 +111,20 @@ public class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return values[toValueIndex(address, dimensionSizes)];
+ return values[(int)toValueIndex(address, dimensionSizes)];
}
catch (IndexOutOfBoundsException e) {
return Double.NaN;
}
}
- private double get(int valueIndex) { return values[valueIndex]; }
-
- private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
+ private double get(long valueIndex) { return values[(int)valueIndex]; }
+
+ private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
@@ -134,21 +134,21 @@ public class IndexedTensor implements Tensor {
return valueIndex;
}
- private static int toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < address.size(); i++) {
- if (address.intLabel(i) >= sizes.size(i)) {
+ if (address.numericLabel(i) >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
}
- valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i);
+ valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
}
return valueIndex;
}
- private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- int product = 1;
+ private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
+ long product = 1;
for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
product *= sizes.size(i);
return product;
@@ -165,22 +165,22 @@ public class IndexedTensor implements Tensor {
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
return Collections.singletonMap(TensorAddress.of(), values[0]);
-
+
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
- for (int i = 0; i < values.length; i++) {
+ for (long i = 0; i < values.length; i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[i]);
+ builder.put(indexes.toAddress(), values[(int)i]);
}
return builder.build();
}
-
+
@Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
public String toString() { return Tensor.toStandardString(this); }
-
+
@Override
public boolean equals(Object other) {
if ( ! ( other instanceof Tensor)) return false;
@@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor {
}
public abstract static class Builder implements Tensor.Builder {
-
+
final TensorType type;
-
+
private Builder(TensorType type) {
this.type = type;
}
@@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor {
return new UnboundBuilder(type);
}
- /**
+ /**
* Create a builder with dimension size information for this instance. Must be one size entry per dimension,
* and, agree with the type size information when specified in the type.
* If sizes are completely specified in the type this size information is redundant.
@@ -210,20 +210,20 @@ public class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
+ throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
"for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
- Optional<Integer> size = type.dimensions().get(i).size();
+ Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
- throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
+ throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
sizes.size(i) +
" but cannot be larger than " + size.get() + " in " + type);
}
-
+
return new BoundBuilder(type, sizes);
}
- public abstract Builder cell(double value, int ... indexes);
+ public abstract Builder cell(double value, long ... indexes);
@Override
public TensorType type() { return type; }
@@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor {
public abstract IndexedTensor build();
}
-
+
/** A bound builder can create the double array directly */
public static class BoundBuilder extends Builder {
@@ -255,15 +255,15 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = new double[(int)sizes.totalSize()];
}
-
+
@Override
- public BoundBuilder cell(double value, int ... indexes) {
- values[toValueIndex(indexes, sizes)] = value;
+ public BoundBuilder cell(double value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes)] = value;
return this;
}
-
+
@Override
public CellBuilder cell() {
return new CellBuilder(type, this);
@@ -271,7 +271,7 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- values[toValueIndex(address, sizes)] = value;
+ values[(int)toValueIndex(address, sizes)] = value;
return this;
}
@@ -286,21 +286,21 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
- int directIndex = cell.getDirectIndex();
+ long directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
- values[directIndex] = value;
+ values[(int)directIndex] = value;
else
super.cell(cell, value);
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
+ /**
+ * Set a cell value by the index in the internal layout of this cell.
* This requires knowledge of the internal layout of cells in this implementation, and should therefore
* probably not be used (but when it can be used it is fast).
*/
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
}
}
@@ -326,13 +326,13 @@ public class IndexedTensor implements Tensor {
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = new double[(int)dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
-
+
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
- List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
+ List<Long> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
for (int i = 0; i < b.dimensions(); i++) {
@@ -343,33 +343,33 @@ public class IndexedTensor implements Tensor {
}
@SuppressWarnings("unchecked")
- private void findDimensionSizes(int currentDimensionIndex, List<Integer> dimensionSizes, List<Object> currentDimension) {
+ private void findDimensionSizes(int currentDimensionIndex, List<Long> dimensionSizes, List<Object> currentDimension) {
if (currentDimensionIndex == dimensionSizes.size())
- dimensionSizes.add(currentDimension.size());
+ dimensionSizes.add((long)currentDimension.size());
else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size())
- throw new IllegalArgumentException("Missing values in dimension " +
+ throw new IllegalArgumentException("Missing values in dimension " +
type.dimensions().get(currentDimensionIndex) + " in " + type);
-
+
for (Object value : currentDimension)
if (value instanceof List)
findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List<Object>)value);
}
@SuppressWarnings("unchecked")
- private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
+ private void fillValues(int currentDimensionIndex, long offset, List<Object> currentDimension,
DimensionSizes sizes, double[] values) {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
- for (int i = 0; i < currentDimension.size(); i++)
+ for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
- (List<Object>) currentDimension.get(i), sizes, values);
+ (List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
- for (int i = 0; i < currentDimension.size(); i++) {
- values[offset + i] = nullAsZero((Double)currentDimension.get(i)); // fill missing values as zero
+ for (long i = 0; i < currentDimension.size(); i++) {
+ values[(int)(offset + i)] = nullAsZero((Double)currentDimension.get((int)i)); // fill missing values as zero
}
}
}
-
+
private double nullAsZero(Double value) {
if (value == null) return 0;
return value;
@@ -382,9 +382,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- int[] indexes = new int[address.size()];
+ long[] indexes = new long[address.size()];
for (int i = 0; i < address.size(); i++) {
- indexes[i] = address.intLabel(i);
+ indexes[i] = address.numericLabel(i);
}
cell(value, indexes);
return this;
@@ -399,7 +399,7 @@ public class IndexedTensor implements Tensor {
*/
@SuppressWarnings("unchecked")
@Override
- public Builder cell(double value, int... indexes) {
+ public Builder cell(double value, long... indexes) {
if (indexes.length != type.dimensions().size())
throw new IllegalArgumentException("Wrong number of indexes (" + indexes.length + ") for " + type);
@@ -414,27 +414,27 @@ public class IndexedTensor implements Tensor {
for (int dimensionIndex = 0; dimensionIndex < indexes.length; dimensionIndex++) {
ensureCapacity(indexes[dimensionIndex], currentValues);
if (dimensionIndex == indexes.length - 1) { // last dimension
- currentValues.set(indexes[dimensionIndex], value);
+ currentValues.set((int)indexes[dimensionIndex], value);
} else {
- if (currentValues.get(indexes[dimensionIndex]) == null)
- currentValues.set(indexes[dimensionIndex], new ArrayList<>());
- currentValues = (List<Object>) currentValues.get(indexes[dimensionIndex]);
+ if (currentValues.get((int)indexes[dimensionIndex]) == null)
+ currentValues.set((int)indexes[dimensionIndex], new ArrayList<>());
+ currentValues = (List<Object>) currentValues.get((int)indexes[dimensionIndex]);
}
}
return this;
}
/** Fill the given list with nulls if necessary to make sure it has a (possibly null) value at the given index */
- private void ensureCapacity(int index, List<Object> list) {
+ private void ensureCapacity(long index, List<Object> list) {
while (list.size() <= index)
list.add(list.size(), null);
}
}
-
+
private final class CellIterator implements Iterator<Cell> {
- private int count = 0;
+ private long count = 0;
private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN);
@@ -451,12 +451,12 @@ public class IndexedTensor implements Tensor {
reusedCell.value = get(indexes.toSourceValueIndex());
return reusedCell;
}
-
+
}
private final class ValueIterator implements Iterator<Double> {
- private int count = 0;
+ private long count = 0;
@Override
public boolean hasNext() {
@@ -466,7 +466,7 @@ public class IndexedTensor implements Tensor {
@Override
public Double next() {
try {
- return values[count++];
+ return values[(int)count++];
}
catch (IndexOutOfBoundsException e) {
throw new NoSuchElementException("No element at position " + count);
@@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor {
}
}
-
+
private final class SuperspaceIterator implements Iterator<SubspaceIterator> {
private final Indexes superindexes;
- /** Those indexes this should iterate over */
+ /** The indexes this should iterate over */
private final List<Integer> subdimensionIndexes;
-
- /**
+
+ /**
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
- * which may be equal to or smaller than the sizes of this tensor
+ * which may be equal to or smaller than the sizes of this tensor
*/
private final DimensionSizes iterateSizes;
- private int count = 0;
-
+ private long count = 0;
+
private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
this.iterateSizes = iterateSizes;
-
+
List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first
@@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor {
else
subdimensionIndexes.add(i);
}
-
+
superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes);
}
-
+
@Override
public boolean hasNext() {
return count < superindexes.size();
@@ -527,60 +527,60 @@ public class IndexedTensor implements Tensor {
*/
public final class SubspaceIterator implements Iterator<Tensor.Cell> {
- /**
+ /**
* This iterator will iterate over the given dimensions, in the order given
* (the first dimension index given is incremented to exhaustion first (i.e is etc.).
* This may be any subset of the dimensions given by address and dimensionSizes.
*/
private final List<Integer> iterateDimensions;
- private final int[] address;
+ private final long[] address;
private final DimensionSizes iterateSizes;
private Indexes indexes;
- private int count = 0;
-
+ private long count = 0;
+
/** A lazy cell for reuse */
private final LazyCell reusedCell;
-
- /**
+
+ /**
* Creates a new subspace iterator
- *
+ *
* @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the
* type of the tensor this iterates over. This iterator will iterate over these
- * dimensions to exhaustion in the order given (the first dimension index given is
+ * dimensions to exhaustion in the order given (the first dimension index given is
* incremented to exhaustion first (i.e is etc.), while other dimensions will be held
* at a constant position.
* This may be any subset of the dimensions given by address and dimensionSizes.
* This is treated as immutable.
- * @param address the address of the first cell of this subspace.
+ * @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
+ private SubspaceIterator(List<Integer> iterateDimensions, long[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
this.address = address;
this.iterateSizes = iterateSizes;
this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
reusedCell = new LazyCell(indexes, Double.NaN);
}
-
+
/** Returns the total number of cells in this subspace */
- public int size() {
+ public long size() {
return indexes.size();
}
-
+
/** Returns the address of the cell this currently points to (which may be an invalid position) */
public TensorAddress address() { return indexes.toAddress(); }
-
+
/** Rewind this iterator to the first element */
- public void reset() {
+ public void reset() {
this.count = 0;
- this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
}
-
+
@Override
public boolean hasNext() {
- return count < indexes.size();
+ return count < indexes.size();
}
-
+
/** Returns the next cell, which is valid until next() is called again */
@Override
public Cell next() {
@@ -605,21 +605,21 @@ public class IndexedTensor implements Tensor {
}
@Override
- int getDirectIndex() { return indexes.toIterationValueIndex(); }
+ long getDirectIndex() { return indexes.toIterationValueIndex(); }
@Override
public TensorAddress getKey() {
return indexes.toAddress();
}
-
+
@Override
public Double getValue() { return value; }
}
// TODO: Make dimensionSizes a class
-
- /**
+
+ /**
* An array of indexes into this tensor which are able to find the next index in the value order.
* next() can be called once per element in the dimensions we iterate over. It must be called once
* before accessing the first position.
@@ -630,8 +630,8 @@ public class IndexedTensor implements Tensor {
private final DimensionSizes iterationSizes;
- protected final int[] indexes;
-
+ protected final long[] indexes;
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -640,7 +640,7 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long size) {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
@@ -648,15 +648,15 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) {
- return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size);
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long size) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, new long[iterateSizes.dimensions()], size);
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes) {
return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
if (size == 0) {
return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available
}
@@ -676,22 +676,22 @@ public class IndexedTensor implements Tensor {
return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size);
}
}
-
+
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
iterationDimensions.add(length - 1 - i);
return iterationDimensions;
}
-
- private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
+
+ private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, long[] indexes) {
this.sourceSizes = sourceSizes;
this.iterationSizes = iterationSizes;
this.indexes = indexes;
}
- private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
- int size = 1;
+ private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ long size = 1;
for (int iterateDimension : iterateDimensions)
size *= sizes.size(iterateDimension);
return size;
@@ -702,25 +702,25 @@ public class IndexedTensor implements Tensor {
return TensorAddress.of(indexes);
}
- public int[] indexesCopy() {
+ public long[] indexesCopy() {
return Arrays.copyOf(indexes, indexes.length);
}
/** Returns a copy of the indexes of this which must not be modified */
- public int[] indexesForReading() { return indexes; }
-
- int toSourceValueIndex() {
- return IndexedTensor.toValueIndex(indexes, sourceSizes);
+ public long[] indexesForReading() { return indexes; }
+
+ long toSourceValueIndex() {
+ return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
- int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
+ long toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
DimensionSizes dimensionSizes() { return iterationSizes; }
/** Returns an immutable list containing a copy of the indexes in this */
- public List<Integer> toList() {
- ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>();
- for (int index : indexes)
+ public List<Long> toList() {
+ ImmutableList.Builder<Long> builder = new ImmutableList.Builder<>();
+ for (long index : indexes)
builder.add(index);
return builder.build();
}
@@ -729,21 +729,21 @@ public class IndexedTensor implements Tensor {
public String toString() {
return "indexes " + Arrays.toString(indexes);
}
-
- public abstract int size();
-
+
+ public abstract long size();
+
public abstract void next();
}
private final static class EmptyIndexes extends Indexes {
- private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 0; }
+ public long size() { return 0; }
@Override
public void next() {}
@@ -752,43 +752,43 @@ public class IndexedTensor implements Tensor {
private final static class SingleValueIndexes extends Indexes {
- private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 1; }
+ public long size() { return 1; }
@Override
public void next() {}
}
-
+
private static class MultiDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final List<Integer> iterateDimensions;
-
- private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+
+ private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
-
+
// Initialize to the (virtual) position before the first cell
indexes[iterateDimensions.get(0)]--;
}
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
- /**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
- *
+ /**
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
+ *
* @throws RuntimeException if this is called more times than its size
*/
@Override
@@ -802,40 +802,42 @@ public class IndexedTensor implements Tensor {
}
}
-
+
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private int lastComputedSourceValueIndex = -1;
-
- private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private long lastComputedSourceValueIndex = -1;
+
+ private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
- int toSourceValueIndex() {
+ @Override
+ long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
// NOTE: We assume the source index always gets computed first. Otherwise using this will produce a runtime exception
- int toIterationValueIndex() { return lastComputedSourceValueIndex; }
+ @Override
+ long toIterationValueIndex() { return lastComputedSourceValueIndex; }
}
/** In this case we can keep track of indexes using a step instead of using the more elaborate computation */
private final static class SingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
-
+
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentSourceValueIndex, currentIterationValueIndex;
+ private long currentSourceValueIndex, currentIterationValueIndex;
/** The iteration step in the value index space */
- private final int sourceStep, iterationStep;
+ private final long sourceStep, iterationStep;
private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ int iterateDimension, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -847,16 +849,16 @@ public class IndexedTensor implements Tensor {
currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes);
currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes);
}
-
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
@@ -868,28 +870,28 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentSourceValueIndex; }
+ long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
- int toIterationValueIndex() { return currentIterationValueIndex; }
+ long toIterationValueIndex() { return currentIterationValueIndex; }
}
/** In this case we only need to keep track of one index */
private final static class EqualSizeSingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentValueIndex;
+ private long currentValueIndex;
/** The iteration step in the value index space */
- private final int step;
+ private final long step;
- private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
+ int iterateDimension, long[] initialIndexes, long size) {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -902,13 +904,13 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
@@ -919,10 +921,10 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentValueIndex; }
+ long toSourceValueIndex() { return currentValueIndex; }
@Override
- int toIterationValueIndex() { return currentValueIndex; }
+ long toIterationValueIndex() { return currentValueIndex; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 618bff0caae..15993072c37 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -27,9 +27,9 @@ public class MappedTensor implements Tensor {
@Override
public TensorType type() { return type; }
-
+
@Override
- public int size() { return cells.size(); }
+ public long size() { return cells.size(); }
@Override
public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); }
@@ -56,16 +56,16 @@ public class MappedTensor implements Tensor {
}
public static class Builder implements Tensor.Builder {
-
+
private final TensorType type;
private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
+
public static Builder of(TensorType type) { return new Builder(type); }
private Builder(TensorType type) {
this.type = type;
}
-
+
public CellBuilder cell() {
return new CellBuilder(type, this);
}
@@ -80,7 +80,7 @@ public class MappedTensor implements Tensor {
}
@Override
- public Builder cell(double value, int... labels) {
+ public Builder cell(double value, long... labels) {
cells.put(TensorAddress.of(labels), value);
return this;
}
@@ -89,24 +89,24 @@ public class MappedTensor implements Tensor {
public MappedTensor build() {
return new MappedTensor(type, cells.build());
}
-
+
}
private static class CellIteratorAdaptor implements Iterator<Cell> {
private final Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator;
-
+
private CellIteratorAdaptor(Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator) {
this.adaptedIterator = adaptedIterator;
}
-
+
@Override
public boolean hasNext() { return adaptedIterator.hasNext(); }
@Override
public Cell next() {
Map.Entry<TensorAddress, Double> entry = adaptedIterator.next();
- return new Cell(entry.getKey(), entry.getValue());
+ return new Cell(entry.getKey(), entry.getValue());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 79bb27fcd1b..0c9ed769c0d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -47,13 +47,13 @@ public class MixedTensor implements Tensor {
/** Returns the size of the tensor measured in number of cells */
@Override
- public int size() { return cells.size(); }
+ public long size() { return cells.size(); }
/** Returns the value at the given address */
@Override
public double get(TensorAddress address) {
- int cellIndex = index.indexOf(address);
- Cell cell = cells.get(cellIndex);
+ long cellIndex = index.indexOf(address);
+ Cell cell = cells.get((int)cellIndex);
if (!address.equals(cell.getKey())) {
throw new IllegalStateException("Unable to find correct cell by direct index.");
}
@@ -113,11 +113,11 @@ public class MixedTensor implements Tensor {
}
/** Returns the size of dense subspaces */
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
-
+
/**
* Base class for building mixed tensors.
*/
@@ -148,7 +148,7 @@ public class MixedTensor implements Tensor {
}
@Override
- public Tensor.Builder cell(double value, int... labels) {
+ public Tensor.Builder cell(double value, long... labels) {
throw new UnsupportedOperationException("Not implemented.");
}
@@ -179,13 +179,13 @@ public class MixedTensor implements Tensor {
index = indexBuilder.index();
}
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
private double[] denseSubspace(TensorAddress sparsePartial) {
if (!denseSubspaceMap.containsKey(sparsePartial)) {
- denseSubspaceMap.put(sparsePartial, new double[denseSubspaceSize()]);
+ denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]);
}
return denseSubspaceMap.get(sparsePartial);
}
@@ -193,21 +193,21 @@ public class MixedTensor implements Tensor {
@Override
public Tensor.Builder cell(TensorAddress address, double value) {
TensorAddress sparsePart = index.sparsePartialAddress(address);
- int denseOffset = index.denseOffset(address);
+ long denseOffset = index.denseOffset(address);
double[] denseSubspace = denseSubspace(sparsePart);
- denseSubspace[denseOffset] = value;
+ denseSubspace[(int)denseOffset] = value;
return this;
}
public Tensor.Builder block(TensorAddress sparsePart, double[] values) {
double[] denseSubspace = denseSubspace(sparsePart);
- System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize());
+ System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize());
return this;
}
@Override
public MixedTensor build() {
- int count = 0;
+ long count = 0;
ImmutableList.Builder<Cell> builder = new ImmutableList.Builder<>();
for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
@@ -215,9 +215,9 @@ public class MixedTensor implements Tensor {
indexBuilder.put(sparsePart, count);
double[] denseSubspace = entry.getValue();
- for (int offset = 0; offset < denseSubspace.length; ++offset) {
+ for (long offset = 0; offset < denseSubspace.length; ++offset) {
TensorAddress cellAddress = index.addressOf(sparsePart, offset);
- double value = denseSubspace[offset];
+ double value = denseSubspace[(int)offset];
builder.add(new Cell(cellAddress, value));
count++;
}
@@ -239,12 +239,12 @@ public class MixedTensor implements Tensor {
public static class UnboundBuilder extends Builder {
private Map<TensorAddress, Double> cells;
- private final int[] dimensionBounds;
+ private final long[] dimensionBounds;
private UnboundBuilder(TensorType type) {
super(type);
cells = new HashMap<>();
- dimensionBounds = new int[type.dimensions().size()];
+ dimensionBounds = new long[type.dimensions().size()];
}
@Override
@@ -268,7 +268,7 @@ public class MixedTensor implements Tensor {
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (dimension.isIndexed()) {
- dimensionBounds[i] = Math.max(address.intLabel(i), dimensionBounds[i]);
+ dimensionBounds[i] = Math.max(address.numericLabel(i), dimensionBounds[i]);
}
}
}
@@ -280,13 +280,13 @@ public class MixedTensor implements Tensor {
if (!dimension.isIndexed()) {
typeBuilder.mapped(dimension.name());
} else {
- int size = dimension.size().orElse(dimensionBounds[i] + 1);
+ long size = dimension.size().orElse(dimensionBounds[i] + 1);
typeBuilder.indexed(dimension.name(), size);
}
}
return typeBuilder.build();
}
-
+
}
/**
@@ -303,8 +303,8 @@ public class MixedTensor implements Tensor {
private final List<TensorType.Dimension> mappedDimensions;
private final List<TensorType.Dimension> indexedDimensions;
- private ImmutableMap<TensorAddress, Integer> sparseMap;
- private int denseSubspaceSize = -1;
+ private ImmutableMap<TensorAddress, Long> sparseMap;
+ private long denseSubspaceSize = -1;
private Index(TensorType type) {
this.type = type;
@@ -314,26 +314,27 @@ public class MixedTensor implements Tensor {
this.denseType = createPartialType(indexedDimensions);
}
- public int indexOf(TensorAddress address) {
+ public long indexOf(TensorAddress address) {
TensorAddress sparsePart = sparsePartialAddress(address);
- if (!sparseMap.containsKey(sparsePart)) {
+ if ( ! sparseMap.containsKey(sparsePart)) {
throw new IllegalArgumentException("Address not found");
}
- int base = sparseMap.get(sparsePart);
- int offset = denseOffset(address);
+ long base = sparseMap.get(sparsePart);
+ long offset = denseOffset(address);
return base + offset;
}
public static class Builder {
+
private final Index index;
- private final ImmutableMap.Builder<TensorAddress, Integer> builder;
+ private final ImmutableMap.Builder<TensorAddress, Long> builder;
public Builder(TensorType type) {
index = new Index(type);
builder = new ImmutableMap.Builder<>();
}
- public void put(TensorAddress address, int index) {
+ public void put(TensorAddress address, long index) {
builder.put(address, index);
}
@@ -347,7 +348,7 @@ public class MixedTensor implements Tensor {
}
}
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
if (denseSubspaceSize == -1) {
denseSubspaceSize = 1;
for (int i = 0; i < type.dimensions().size(); ++i) {
@@ -360,7 +361,7 @@ public class MixedTensor implements Tensor {
}
return denseSubspaceSize;
}
-
+
private TensorAddress sparsePartialAddress(TensorAddress address) {
if (type.dimensions().size() != address.size()) {
throw new IllegalArgumentException("Tensor type and address are not of same size.");
@@ -375,13 +376,13 @@ public class MixedTensor implements Tensor {
return builder.build();
}
- private int denseOffset(TensorAddress address) {
- int innerSize = 1;
- int offset = 0;
+ private long denseOffset(TensorAddress address) {
+ long innerSize = 1;
+ long offset = 0;
for (int i = type.dimensions().size(); --i >= 0; ) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (dimension.isIndexed()) {
- int label = address.intLabel(i);
+ long label = address.numericLabel(i);
offset += label * innerSize;
innerSize *= dimension.size().orElseThrow(() ->
new IllegalArgumentException("Unknown size of indexed dimension."));
@@ -390,18 +391,18 @@ public class MixedTensor implements Tensor {
return offset;
}
- private TensorAddress denseOffsetToAddress(int denseOffset) {
+ private TensorAddress denseOffsetToAddress(long denseOffset) {
if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
throw new IllegalArgumentException("Offset out of bounds");
}
- int restSize = denseOffset;
- int innerSize = denseSubspaceSize;
- int[] labels = new int[indexedDimensions.size()];
+ long restSize = denseOffset;
+ long innerSize = denseSubspaceSize;
+ long[] labels = new long[indexedDimensions.size()];
for (int i = 0; i < labels.length; ++i) {
TensorType.Dimension dimension = indexedDimensions.get(i);
- int dimensionSize = dimension.size().orElseThrow(() ->
+ long dimensionSize = dimension.size().orElseThrow(() ->
new IllegalArgumentException("Unknown size of indexed dimension."));
innerSize /= dimensionSize;
@@ -411,7 +412,7 @@ public class MixedTensor implements Tensor {
return TensorAddress.of(labels);
}
- private TensorAddress addressOf(TensorAddress sparsePart, int denseOffset) {
+ private TensorAddress addressOf(TensorAddress sparsePart, long denseOffset) {
TensorAddress densePart = denseOffsetToAddress(denseOffset);
String[] labels = new String[type.dimensions().size()];
int mappedIndex = 0;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index e3398850373..23ef0772aea 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -6,11 +6,11 @@ import com.google.common.annotations.Beta;
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
* dimensions.
- *
+ *
* @author bratseth
*/
-// Implementation notes:
-// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
+// Implementation notes:
+// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
// We also avoid non-essential error checking.
// - We can add support for string labels later without breaking the API
@Beta
@@ -19,7 +19,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension=label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final int[] labels;
+ private final long[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -27,36 +27,36 @@ public class PartialAddress {
builder.dimensionNames = null; // invalidate builder to safely take over array ownership
builder.labels = null;
}
-
+
/** Returns the int label of this dimension, or -1 if no label is specified for it */
- int intLabel(String dimensionName) {
+ long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
return labels[i];
return -1;
}
-
+
public static class Builder {
private String[] dimensionNames;
- private int[] labels;
+ private long[] labels;
private int index = 0;
-
+
public Builder(int size) {
dimensionNames = new String[size];
- labels = new int[size];
+ labels = new long[size];
}
-
- public void add(String dimensionName, int label) {
+
+ public void add(String dimensionName, long label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
}
-
+
public PartialAddress build() {
return new PartialAddress(this);
}
-
+
}
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 2ed211539d8..0c948f1fbee 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -52,14 +52,14 @@ import java.util.function.Function;
public interface Tensor {
// ----------------- Accessors
-
+
TensorType type();
/** Returns whether this have any cells */
default boolean isEmpty() { return size() == 0; }
/** Returns the number of cells in this */
- int size();
+ long size();
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
@@ -70,13 +70,13 @@ public interface Tensor {
/** Returns the values of this in some undefined order */
Iterator<Double> valueIterator();
- /**
+ /**
* Returns an immutable map of the cells of this in no particular order.
- * This may be expensive for some implementations - avoid when possible
+ * This may be expensive for some implementations - avoid when possible
*/
Map<TensorAddress, Double> cells();
- /**
+ /**
* Returns the value of this as a double if it has no dimensions and one value
*
* @throws IllegalStateException if this does not have zero dimensions and one value
@@ -87,9 +87,9 @@ public interface Tensor {
if (size() == 0) return Double.NaN;
return valueIterator().next();
}
-
+
// ----------------- Primitive tensor functions
-
+
default Tensor map(DoubleUnaryOperator mapper) {
return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
}
@@ -108,7 +108,7 @@ public interface Tensor {
}
default Tensor rename(String fromDimension, String toDimension) {
- return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
+ return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
Collections.singletonList(toDimension)).evaluate();
}
@@ -123,13 +123,13 @@ public interface Tensor {
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
-
- static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) {
+
+ static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) {
return new Generate(type, valueSupplier).evaluate();
}
-
+
// ----------------- Composite tensor functions which have a defined primitive mapping
-
+
default Tensor l1Normalize(String dimension) {
return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
}
@@ -231,7 +231,7 @@ public interface Tensor {
if (cellEntries.isEmpty()) return "{}";
return "{" + cellEntries.get(0).getValue() +"}";
}
-
+
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
@@ -253,7 +253,7 @@ public interface Tensor {
*/
boolean equals(Object o);
- /**
+ /**
* Implement here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
*/
@@ -328,13 +328,13 @@ public interface Tensor {
@Override
public TensorAddress getKey() { return address; }
- /**
+ /**
* Returns the direct index which can be used to locate this cell, or -1 if not available.
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
- int getDirectIndex() { return -1; }
-
+ long getDirectIndex() { return -1; }
+
@Override
public Double getValue() { return value; }
@@ -388,20 +388,20 @@ public interface Tensor {
/** Returns the type this is building */
TensorType type();
-
+
/** Return a cell builder */
CellBuilder cell();
/** Add a cell */
Builder cell(TensorAddress address, double value);
-
+
/** Add a cell */
- Builder cell(double value, int ... labels);
+ Builder cell(double value, long ... labels);
- /**
- * Add a cell
- *
- * @param cell a cell providing the location at which to add this cell
+ /**
+ * Add a cell
+ *
+ * @param cell a cell providing the location at which to add this cell
* @param value the value to assign to the cell
*/
default Builder cell(Cell cell, double value) {
@@ -409,12 +409,12 @@ public interface Tensor {
}
Tensor build();
-
+
class CellBuilder {
private final TensorAddress.Builder addressBuilder;
private final Tensor.Builder tensorBuilder;
-
+
CellBuilder(TensorType type, Tensor.Builder tensorBuilder) {
addressBuilder = new TensorAddress.Builder(type);
this.tensorBuilder = tensorBuilder;
@@ -425,7 +425,7 @@ public interface Tensor {
return this;
}
- public CellBuilder label(String dimension, int label) {
+ public CellBuilder label(String dimension, long label) {
return label(dimension, String.valueOf(label));
}
@@ -436,5 +436,5 @@ public interface Tensor {
}
}
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 7161450d5d5..38553497478 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -2,16 +2,10 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
import java.util.Objects;
import java.util.Optional;
-import java.util.Set;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
@@ -26,29 +20,29 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return new StringTensorAddress(labels);
}
- public static TensorAddress of(int ... labels) {
- return new IntTensorAddress(labels);
+ public static TensorAddress of(long ... labels) {
+ return new NumericTensorAddress(labels);
}
/** Returns the number of labels in this */
public abstract int size();
-
+
/**
- * Returns the i'th label in this
- *
+ * Returns the i'th label in this
+ *
* @throws IllegalArgumentException if there is no label at this index
*/
public abstract String label(int i);
/**
- * Returns the i'th label in this as an int.
- * Prefer this if you know that this is an integer address, but not otherwise.
+ * Returns the i'th label in this as a long.
+ * Prefer this if you know that this is a numeric address, but not otherwise.
*
* @throws IllegalArgumentException if there is no label at this index
*/
- public abstract int intLabel(int i);
+ public abstract long numericLabel(int i);
- public abstract TensorAddress withLabel(int labelIndex, int label);
+ public abstract TensorAddress withLabel(int labelIndex, long label);
public final boolean isEmpty() { return size() == 0; }
@@ -102,25 +96,25 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private StringTensorAddress(String ... labels) {
this.labels = Arrays.copyOf(labels, labels.length);
}
-
+
@Override
public int size() { return labels.length; }
-
+
@Override
public String label(int i) { return labels[i]; }
-
+
@Override
- public int intLabel(int i) {
+ public long numericLabel(int i) {
try {
- return Integer.parseInt(labels[i]);
- }
+ return Long.parseLong(labels[i]);
+ }
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i);
+ throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i);
}
}
-
+
@Override
- public TensorAddress withLabel(int index, int label) {
+ public TensorAddress withLabel(int index, long label) {
String[] labels = Arrays.copyOf(this.labels, this.labels.length);
labels[index] = String.valueOf(label);
return new StringTensorAddress(labels);
@@ -133,11 +127,11 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
- private static final class IntTensorAddress extends TensorAddress {
+ private static final class NumericTensorAddress extends TensorAddress {
- private final int[] labels;
+ private final long[] labels;
- private IntTensorAddress(int[] labels) {
+ private NumericTensorAddress(long[] labels) {
this.labels = Arrays.copyOf(labels, labels.length);
}
@@ -148,13 +142,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public String label(int i) { return String.valueOf(labels[i]); }
@Override
- public int intLabel(int i) { return labels[i]; }
+ public long numericLabel(int i) { return labels[i]; }
@Override
- public TensorAddress withLabel(int index, int label) {
- int[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ public TensorAddress withLabel(int index, long label) {
+ long[] labels = Arrays.copyOf(this.labels, this.labels.length);
labels[index] = label;
- return new IntTensorAddress(labels);
+ return new NumericTensorAddress(labels);
}
@Override
@@ -169,7 +163,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private final TensorType type;
private final String[] labels;
-
+
public Builder(TensorType type) {
this(type, new String[type.dimensions().size()]);
}
@@ -193,7 +187,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
labels[labelIndex.get()] = label;
return this;
}
-
+
/** Creates a copy of this which can be modified separately */
public Builder copy() {
return new Builder(type, Arrays.copyOf(labels, labels.length));
@@ -202,7 +196,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public TensorAddress build() {
for (int i = 0; i < labels.length; i++)
if (labels[i] == null)
- throw new IllegalArgumentException("Missing a value for dimension " +
+ throw new IllegalArgumentException("Missing a value for dimension " +
type.dimensions().get(i).name() + " for " + type);
return TensorAddress.of(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index da8ab3bb0ec..9b3a9328f07 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -96,7 +96,7 @@ class TensorParser {
if (valueEnd < 0)
throw new IllegalArgumentException("A tensor string must end by '}'");
}
-
+
TensorAddress address = addressBuilder.build();
Double value = asDouble(address, s.substring(0, valueEnd).trim());
builder.cell(address, value);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index c05c35d6df3..b396f831de0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,14 +53,17 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the number of dimensions of this: dimensions().size() */
+ public int rank() { return dimensions.size(); }
+
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
-
+
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
}
-
+
/** Returns the dimension with this name, or empty if not present */
public Optional<Dimension> dimension(String name) {
return indexOfDimension(name).map(i -> dimensions.get(i));
@@ -74,7 +77,7 @@ public class TensorType {
return Optional.empty();
}
- /**
+ /**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
*/
@@ -128,15 +131,15 @@ public class TensorType {
private final String name;
- private Dimension(String name) {
+ private Dimension(String name) {
Objects.requireNonNull(name, "A tensor name cannot be null");
- this.name = name;
+ this.name = name;
}
public final String name() { return name; }
/** Returns the size of this dimension if it is bound, empty otherwise */
- public abstract Optional<Integer> size();
+ public abstract Optional<Long> size();
public abstract Type type();
@@ -146,7 +149,7 @@ public class TensorType {
/** Returns true if this is an indexed bound or unboun type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
- /**
+ /**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
* types. This works by degrading to the type making the fewer promises.
* [N] + [M] = [min(N, M)]
@@ -165,7 +168,7 @@ public class TensorType {
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
}
-
+
@Override
public abstract String toString();
@@ -175,36 +178,38 @@ public class TensorType {
if (other == null || getClass() != other.getClass()) return false;
return name.equals(((Dimension)other).name);
}
-
+
@Override
public int hashCode() {
return name.hashCode();
}
-
+
@Override
public int compareTo(Dimension other) {
return this.name.compareTo(other.name);
}
-
- public static Dimension indexed(String name, int size) {
+
+ public static Dimension indexed(String name, long size) {
return new IndexedBoundDimension(name, size);
}
-
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
- private final Integer size;
+ private final Long size;
- private IndexedBoundDimension(String name, int size) {
+ private IndexedBoundDimension(String name, long size) {
super(name);
if (size < 1)
throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1");
+ if (size > Integer.MAX_VALUE)
+ throw new IllegalArgumentException("Size of bound dimension '" + name + "' cannot be larger than " + Integer.MAX_VALUE);
this.size = size;
}
@Override
- public Optional<Integer> size() { return Optional.of(size); }
+ public Optional<Long> size() { return Optional.of(size); }
@Override
public Type type() { return Type.indexedBound; }
@@ -245,7 +250,7 @@ public class TensorType {
}
@Override
- public Optional<Integer> size() { return Optional.empty(); }
+ public Optional<Long> size() { return Optional.empty(); }
@Override
public Type type() { return Type.indexedUnbound; }
@@ -266,7 +271,7 @@ public class TensorType {
}
@Override
- public Optional<Integer> size() { return Optional.empty(); }
+ public Optional<Long> size() { return Optional.empty(); }
@Override
public Type type() { return Type.mapped; }
@@ -289,9 +294,9 @@ public class TensorType {
public Builder() {
}
- /**
- * Creates a builder containing a combination of the dimensions of the given types
- *
+ /**
+ * Creates a builder containing a combination of the dimensions of the given types
+ *
* If the same dimension is indexed with different size restrictions the largest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
@@ -325,9 +330,12 @@ public class TensorType {
}
}
- /**
+ /** Returns the current number of dimensions in this */
+ public int rank() { return dimensions.size(); }
+
+ /**
* Adds a new dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
private Builder add(Dimension dimension) {
@@ -346,16 +354,16 @@ public class TensorType {
return this;
}
- /**
+ /**
* Adds a bound indexed dimension to this
*
* @throws IllegalArgumentException if the dimension is already present
*/
- public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); }
+ public Builder indexed(String name, long size) { return add(new IndexedBoundDimension(name, size)); }
/**
* Adds an unbound indexed dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
public Builder indexed(String name) {
@@ -375,7 +383,7 @@ public class TensorType {
public Builder dimension(Dimension dimension) {
return add(dimension);
}
-
+
/** Returns the given dimension, or empty if none is present */
public Optional<Dimension> getDimension(String dimension) {
return Optional.ofNullable(dimensions.get(dimension));
@@ -393,7 +401,7 @@ public class TensorType {
public TensorType build() {
return new TensorType(dimensions.values());
}
-
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 84caca78fb2..3db661f8a23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -2,16 +2,17 @@
package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
-
-import java.util.HashMap;
+import com.yahoo.tensor.Tensor;
/**
* An evaluation context which is passed down to all nested functions during evaluation.
- * The default context is empty to allow various evaluation frameworks to support their own implementation.
- *
+ *
* @author bratseth
*/
@Beta
public interface EvaluationContext {
+ /** Returns the tensor bound to this name, or null if none */
+ Tensor getTensor(String name);
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index cf704c15f4f..db8a66a5fa2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext {
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
- /** Returns the tensor bound to this name, or null if none */
- public Tensor get(String name) { return bindings.get(name); }
+ @Override
+ public Tensor getTensor(String name) { return bindings.get(name); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 8ade181bdb7..1f6ad050368 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -12,18 +12,18 @@ import java.util.List;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
- *
+ *
* @author bratseth
*/
@Beta
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
-
+
public VariableTensor(String name) {
this.name = name;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
@Override
public Tensor evaluate(EvaluationContext context) {
- return ((MapEvaluationContext)context).get(name);
+ return context.getTensor(name);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 8f4dbf014a7..191c7988443 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
- *
+ *
* @author bratseth
*/
@Beta
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 1dbb94fdb20..d4affe0ef9b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -15,7 +15,7 @@ import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
- *
+ *
* @author bratseth
*/
@Beta
@@ -67,15 +67,15 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
-
- private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
@@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction {
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
return tensor.multiply(unitTensor);
}
-
+
}
/** Returns the type resulting from concatenating a and b */
@@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
@@ -144,12 +144,12 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Combine two addresses, adding the offset to the concat dimension
*
- * @return the combined address or null if the addresses are incompatible
+ * @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, int concatOffset, String concatDimension) {
- int[] combinedLabels = new int[concatType.dimensions().size()];
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
Arrays.fill(combinedLabels, -1);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
@@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (concatDimension == toIndex) {
- to[toIndex] = from.intLabel(i) + concatOffset;
+ to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false;
- to[toIndex] = from.intLabel(i);
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
}
}
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 4ac7b21ba90..14ed38718ce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -10,18 +10,18 @@ import java.util.List;
/**
* A function which returns a constant tensor.
- *
+ *
* @author bratseth
*/
@Beta
public class ConstantTensor extends PrimitiveTensorFunction {
private final Tensor constant;
-
+
public ConstantTensor(String tensorString) {
this.constant = Tensor.from(tensorString);
}
-
+
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index bbdbd5c3df1..653be8dacf0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -11,19 +11,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
- *
+ *
* @author bratseth
*/
public class Diag extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> diagFunction;
-
+ private final Function<List<Long>, Double> diagFunction;
+
public Diag(TensorType type) {
this.type = type;
this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 6ea73b7f310..ef2770c04f5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -15,31 +15,31 @@ import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
- *
+ *
* @author bratseth
*/
@Beta
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> generator;
+ private final Function<List<Long>, Double> generator;
/**
* Creates a generated tensor
- *
+ *
* @param type the type of the tensor
- * @param generator the function generating values from a list of ints specifying the indexes of the
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Integer>, Double> generator) {
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
this.type = type;
this.generator = generator;
}
-
+
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
@@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
-
+
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
@@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
-
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 8c4dbfb0acb..174a8e4c435 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
+ /** Returns the type resulting from applying Join to the two given types */
+ public static TensorType outputType(TensorType a, TensorType b) {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < a.dimensions().size(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ for (int j = 0; j < b.dimensions().size(); ++j) {
+ TensorType.Dimension bDim = b.dimensions().get(j);
+ if (aDim.name().equals(bDim.name())) { // include
+ if (aDim.isIndexed() && bDim.isIndexed()) {
+ if (aDim.size().isPresent() || bDim.size().isPresent())
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE),
+ bDim.size().orElse(Long.MAX_VALUE)));
+ else
+ typeBuilder.indexed(aDim.name());
+ }
+ else {
+ typeBuilder.mapped(aDim.name());
+ }
+ }
+ }
+ }
+ return typeBuilder.build();
+ }
+
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@@ -88,17 +112,17 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
- for (int i = 0; i < joinedLength; i++)
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
+ for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
@@ -114,7 +138,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -126,7 +150,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -134,21 +158,21 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Tensor.Cell> superspace, int superspaceSize,
+ private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
+ long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -200,7 +224,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -235,7 +259,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -252,15 +276,15 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -271,7 +295,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -340,7 +364,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -360,7 +384,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a9872bb42d8..a5e1a016a41 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -32,6 +33,8 @@ public class Map extends PrimitiveTensorFunction {
this.mapper = mapper;
}
+ public static TensorType outputType(TensorType inputType) { return inputType; }
+
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index bb27e937699..4071917c2b5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.List;
@@ -14,13 +15,17 @@ public class Matmul extends CompositeTensorFunction {
private final TensorFunction argument1, argument2;
private final String dimension;
-
+
public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
}
+ public static TensorType outputType(TensorType a, TensorType b, String dimension) {
+ return Join.outputType(a, b);
+ }
+
@Override
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
@@ -39,7 +44,7 @@ public class Matmul extends CompositeTensorFunction {
Reduce.Aggregator.sum,
dimension);
}
-
+
@Override
public String toString(ToStringContext context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index efb7b9e500c..b7c9a5d2342 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor;
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
* Primitive tensor functions are fully inspectable.
- *
+ *
* @author bratseth
*/
@Beta
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 457763e97ba..958ef85d1dc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -22,11 +22,11 @@ import java.util.stream.Stream;
public class Random extends CompositeTensorFunction {
private final TensorType type;
-
+
public Random(TensorType type) {
this.type = type;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index e2b39a2048d..8e7f4e4c773 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -12,19 +12,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
* indexes of each position.
- *
+ *
* @author bratseth
*/
public class Range extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> rangeFunction;
-
+ private final Function<List<Long>, Double> rangeFunction;
+
public Range(TensorType type) {
this.type = type;
this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index cfc78be7e0c..de9f90a5804 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -19,7 +19,7 @@ import java.util.Objects;
import java.util.Set;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
@@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction {
/**
* Creates a reduce function.
- *
+ *
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
@@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
+ public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name()))
+ b.dimension(dimension);
+ }
+ return b.build();
+ }
+
public TensorFunction argument() { return argument; }
@Override
@@ -82,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction {
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
-
+
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
@@ -94,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
@@ -103,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction {
return reduceIndexedVector((IndexedTensor)argument);
else
return reduceAllGeneral(argument);
-
+
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argument.type().dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
TensorType reducedType = builder.build();
-
+
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
@@ -122,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction {
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
-
+
return reducedBuilder.build();
}
-
+
private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
Set<Integer> indexesToRemove = new HashSet<>();
for (String dimensionToRemove : this.dimensions)
@@ -138,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction {
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
-
+
private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
@@ -154,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static abstract class ValueAggregator {
-
+
private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
@@ -165,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction {
case min : return new MinAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
-
+
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
-
+
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
-
+
}
-
+
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
-
+
@Override
public void aggregate(double value) {
valueCount++;
@@ -188,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public double aggregatedValue() {
+ public double aggregatedValue() {
return valueSum / valueCount;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6b0daf1b49a..ec9b762a41c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -19,7 +17,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -29,6 +27,10 @@ public class Rename extends PrimitiveTensorFunction {
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
+ }
+
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
@@ -42,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -62,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -70,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -86,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -95,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 99f79cb735a..f1dadba2a29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -14,127 +14,112 @@ import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
+ * such that they can be inspected and will return a parsable toString.
+ *
* @author bratseth
*/
@Beta
public class ScalarFunctions {
- public static DoubleBinaryOperator add() { return new Addition(); }
- public static DoubleBinaryOperator multiply() { return new Multiplication(); }
- public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleBinaryOperator add() { return new Add(); }
+ public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleBinaryOperator multiply() { return new Multiply(); }
+
+ public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
- public static DoubleUnaryOperator exp() { return new Exponent(); }
- public static Function<List<Integer>, Double> random() { return new Random(); }
- public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
- public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+
+ public static Function<List<Long>, Double> random() { return new Random(); }
+ public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
+ public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
- public static class Addition implements DoubleBinaryOperator {
+ // Binary operators -----------------------------------------------------------------------------
+ public static class Add implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left + right; }
-
@Override
public String toString() { return "f(a,b)(a + b)"; }
+ }
+ public static class Equal implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Multiplication implements DoubleBinaryOperator {
+ public static class Exp implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
+ @Override
+ public String toString() { return "f(a)(exp(a))"; }
+ }
+ public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
-
@Override
public String toString() { return "f(a,b)(a * b)"; }
-
}
- public static class Division implements DoubleBinaryOperator {
-
+ public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
-
@Override
public String toString() { return "f(a,b)(a / b)"; }
}
- public static class Equal implements DoubleBinaryOperator {
+ // Unary operators ------------------------------------------------------------------------------
+ public static class Acos implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
-
+ public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
- public String toString() { return "f(a,b)(a==b)"; }
- }
-
- public static class Square implements DoubleUnaryOperator {
-
- @Override
- public double applyAsDouble(double operand) { return operand * operand; }
-
- @Override
- public String toString() { return "f(a)(a * a)"; }
-
+ public String toString() { return "f(a)(acos(a))"; }
}
public static class Sqrt implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
-
@Override
public String toString() { return "f(a)(sqrt(a))"; }
-
}
- public static class Exponent implements DoubleUnaryOperator {
+ public static class Square implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
+ public double applyAsDouble(double operand) { return operand * operand; }
@Override
- public String toString() { return "f(a)(exp(a))"; }
+ public String toString() { return "f(a)(a * a)"; }
}
- public static class Random implements Function<List<Integer>, Double> {
-
- @Override
- public Double apply(List<Integer> values) {
- return ThreadLocalRandom.current().nextDouble();
- }
-
- @Override
- public String toString() { return "random"; }
+ // Variable-length operators -----------------------------------------------------------------------------
- }
-
- public static class EqualElements implements Function<List<Integer>, Double> {
-
+ public static class EqualElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
-
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
if (values.isEmpty()) return 1.0;
- for (Integer value : values)
+ for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
-
@Override
- public String toString() {
+ public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
-
+
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
@@ -143,30 +128,34 @@ public class ScalarFunctions {
}
return b.toString();
}
-
}
- public static class SumElements implements Function<List<Integer>, Double> {
+ public static class Random implements Function<List<Long>, Double> {
+ @Override
+ public Double apply(List<Long> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+ @Override
+ public String toString() { return "random"; }
+ }
+ public static class SumElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
-
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
- int sum = 0;
- for (Integer value : values)
+ public Double apply(List<Long> values) {
+ long sum = 0;
+ for (Long value : values)
sum += value;
return (double)sum;
}
-
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bf279eb24d8..c856b548180 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -2,6 +2,8 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.List;
@@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
+
+ public static TensorType outputType(TensorType inputType, String dimension) {
+ return Reduce.outputType(inputType, ImmutableList.of(dimension));
+ }
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index cabcce198d1..533a46f87fe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -12,7 +12,7 @@ import java.util.List;
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
* All tensor functions are immutable.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,11 +48,11 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
- *
+ *
* @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
-
+
@Override
public String toString() { return toString(ToStringContext.empty()); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index e8c425d49e0..416b28afa22 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -24,7 +24,7 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
- *
+ *
* @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data
* @param buffer the buffer containing the tensor binary data
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 8b7325ec211..1e830bac461 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -16,9 +16,9 @@ import java.util.Optional;
*
* Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
* Cell_values = [double, double, double, ...]*
- * where values are encoded in order of increasing indexes in each dimension, increasing
+ * where values are encoded in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
- *
+ *
* @author bratseth
*/
@Beta
@@ -36,7 +36,7 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putInt1_4Bytes(tensor.type().dimensions().size());
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
buffer.putUtf8String(tensor.type().dimensions().get(i).name());
- buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
+ buffer.putInt1_4Bytes((int)tensor.dimensionSizes().size(i)); // XXX: Size truncation
}
}
@@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat {
type = optionalType.get();
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
- throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
" cannot be assigned to type " + type);
sizes = sizesFromType(serializedType);
}
@@ -71,7 +71,7 @@ public class DenseBinaryFormat implements BinaryFormat {
int dimensionCount = buffer.getInt1_4Bytes();
TensorType.Builder builder = new TensorType.Builder();
for (int i = 0; i < dimensionCount; i++)
- builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
return builder.build();
}
@@ -84,7 +84,7 @@ public class DenseBinaryFormat implements BinaryFormat {
}
private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- for (int i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.getDouble());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index 61dfa888567..34e6cccf0f0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -46,16 +46,16 @@ class MixedBinaryFormat implements BinaryFormat {
buffer.putInt1_4Bytes(denseDimensions.size());
for (TensorType.Dimension dimension : denseDimensions) {
buffer.putUtf8String(dimension.name());
- buffer.putInt1_4Bytes(dimension.size().orElseThrow(() ->
- new IllegalArgumentException("Unknown size of indexed dimension.")));
+ buffer.putInt1_4Bytes((int)dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension.")).longValue()); // XXX: Size truncation
}
}
private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
- int denseSubspaceSize = tensor.denseSubspaceSize();
+ long denseSubspaceSize = tensor.denseSubspaceSize();
if (sparseDimensions.size() > 0) {
- buffer.putInt1_4Bytes(tensor.size() / denseSubspaceSize);
+ buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation
}
Iterator<Tensor.Cell> cellIterator = tensor.cellIterator();
while (cellIterator.hasNext()) {
@@ -98,7 +98,7 @@ class MixedBinaryFormat implements BinaryFormat {
}
int numIndexedDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numIndexedDimensions; ++i) {
- builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
}
return builder.build();
}
@@ -106,21 +106,21 @@ class MixedBinaryFormat implements BinaryFormat {
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
- int denseSubspaceSize = builder.denseSubspaceSize();
+ long denseSubspaceSize = builder.denseSubspaceSize();
int numBlocks = 1;
if (sparseDimensions.size() > 0) {
numBlocks = buffer.getInt1_4Bytes();
}
- double[] denseSubspace = new double[denseSubspaceSize];
+ double[] denseSubspace = new double[(int)denseSubspaceSize];
for (int i = 0; i < numBlocks; ++i) {
TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType);
for (TensorType.Dimension sparseDimension : sparseDimensions) {
sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
}
- for (int denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
- denseSubspace[denseOffset] = buffer.getDouble();
+ for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
+ denseSubspace[(int)denseOffset] = buffer.getDouble();
}
builder.block(sparseAddress.build(), denseSubspace);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 19969506eca..0cd3ff77aca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -3,13 +3,14 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.text.Utf8;
-import java.util.*;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
/**
* Implementation of a sparse binary format for a tensor on the form:
@@ -39,7 +40,7 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- buffer.putInt1_4Bytes(tensor.size());
+ buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
encodeAddress(buffer, cell.getKey());
@@ -79,8 +80,8 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
- int numCells = buffer.getInt1_4Bytes();
- for (int i = 0; i < numCells; ++i) {
+ long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation
+ for (long i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
decodeAddress(buffer, cellBuilder, type);
cellBuilder.value(buffer.getDouble());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 7467554790a..01a1d023f2b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -46,9 +46,9 @@ public class TypedBinaryFormat {
return result;
}
- /**
- * Decode some data to a tensor
- *
+ /**
+ * Decode some data to a tensor
+ *
* @param type the type to decode and validate to, or empty to use the type given in the data
* @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array
* @return the resulting tensor
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index d199dd3a876..abdb3071bf7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -13,14 +13,14 @@ import java.util.stream.Collectors;
/**
* Microbenchmark of tensor operations.
- *
+ *
* @author bratseth
*/
public class TensorFunctionBenchmark {
private final static Random random = new Random();
-
- public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
+
+ public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
boolean extraSpace) {
Tensor queryVector = vectors(1, 300, dimensionType).get(0);
if (extraSpace) {
@@ -34,7 +34,7 @@ public class TensorFunctionBenchmark {
long totalTime = System.currentTimeMillis() - startTime;
return (double)totalTime / (double)iterations;
}
-
+
private Tensor unitVector(String dimension) {
return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build())
.cell().label(dimension, 0).value(1).build();
@@ -49,11 +49,11 @@ public class TensorFunctionBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
+ TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
+ new VariableTensor("argument"), (a, b) -> a * b),
Reduce.Aggregator.sum).toPrimitive();
MapEvaluationContext context = new MapEvaluationContext();
-
+
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
double dotProduct = dotProductFunction.evaluate(context).asDouble();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 30078b4a826..38a8329bff1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -4,7 +4,6 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
@@ -12,20 +11,18 @@ import com.yahoo.tensor.functions.TensorFunction;
import org.junit.Test;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
-import static org.junit.Assert.assertEquals;
import static com.yahoo.tensor.TensorType.Dimension.Type;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Tests tensor functionality
- *
+ *
* @author bratseth
*/
public class TensorTestCase {
@@ -99,7 +96,7 @@ public class TensorTestCase {
ImmutableList.of("y", "x")));
assertEquals(Tensor.from("{ {x:0,y:0}:0, {x:0,y:1}:0, {x:1,y:0}:0, {x:1,y:1}:1, {x:2,y:0}:0, {x:2,y:1}:2, }"),
Tensor.generate(new TensorType.Builder().indexed("x", 3).indexed("y", 2).build(),
- (List<Integer> indexes) -> (double)indexes.get(0)*indexes.get(1)));
+ (List<Long> indexes) -> (double)indexes.get(0)*indexes.get(1)));
assertEquals(Tensor.from("{ {x:0,y:0,z:0}:0, {x:0,y:1,z:0}:1, {x:1,y:0,z:0}:1, {x:1,y:1,z:0}:2, {x:2,y:0,z:0}:2, {x:2,y:1,z:0}:3, "+
" {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"),
Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
@@ -108,7 +105,7 @@ public class TensorTestCase {
Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x"));
}
-
+
/** Test the same computation made in various ways which are implemented with special-case optimizations */
@Test
public void testOptimizedComputation() {
@@ -130,7 +127,7 @@ public class TensorTestCase {
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
-
+
// Test the unoptimized path by joining in another dimension
Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build();
Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build();
@@ -138,7 +135,7 @@ public class TensorTestCase {
Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK);
assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace)));
}
-
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
@@ -161,7 +158,7 @@ public class TensorTestCase {
private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) {
return vectors(vectorSize, dimensionType, 1).get(0);
}
-
+
/** Create a list of vectors having a single dimension x */
private List<Tensor> vectors(TensorType.Dimension.Type dimensionType, int vectorCount) {
return vectors(3, dimensionType, vectorCount);
@@ -179,8 +176,8 @@ public class TensorTestCase {
}
return tensors;
}
-
- /**
+
+ /**
* Create a matrix of vectors (in dimension i) where each vector has the dimension x.
* This matrix contains the same vectors as returned by createVectors, in a single list element for convenience.
*/
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
index fab53218b2c..f11c068bd74 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
@@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals;
* @author bratseth
*/
public class JoinTestCase {
-
+
/** Test the indexed subspace join optimization */
@Test
public void testJoinIndexedSubspace() {
Tensor t1, t2;
-
+
t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}");
t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}");
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"),
@@ -34,10 +34,10 @@ public class JoinTestCase {
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"),
t2.divide(t1));
}
-
+
@Test
public void testGeneralJoin() {
- assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
+ assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }")
.divide(Tensor.from("tensor(y[]):{{y:0}:2}")));
@@ -45,5 +45,5 @@ public class JoinTestCase {
Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }")
.divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }")));
}
-
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
new file mode 100644
index 00000000000..9643c0a56e7
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -0,0 +1,97 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class MatmulTestCase {
+
+ @Test
+ public void testMatmul2d() {
+ // d0 is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
+ ab.cell( 1,0, 0);
+ ab.cell( 2,0, 1);
+ ab.cell( 3,0, 2);
+ ab.cell( 4,1, 0);
+ ab.cell( 5,1, 1);
+ ab.cell( 6,1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])"));
+ bb.cell( 7,0, 0);
+ bb.cell( 8,0, 1);
+ bb.cell( 9,1, 0);
+ bb.cell(10,1, 1);
+ bb.cell(11,2, 0);
+ bb.cell(12,2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])"));
+ rb.cell( 58,0, 0);
+ rb.cell( 64,0, 1);
+ rb.cell(139,1, 0);
+ rb.cell(154,1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1")
+ .rename("d2","d1");
+ assertEquals(r, result);
+ }
+
+ @Test
+ public void testMatmul3d() {
+ // Convention: a is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])"));
+ ab.cell( 1,0, 0, 0);
+ ab.cell( 2,0, 0, 1);
+ ab.cell( 3,0, 0, 2);
+ ab.cell( 4,0, 1, 0);
+ ab.cell( 5,0, 1, 1);
+ ab.cell( 6,0, 1, 2);
+ ab.cell( 7,1, 0, 0);
+ ab.cell( 8,1, 0, 1);
+ ab.cell( 9,1, 0, 2);
+ ab.cell(10,1, 1, 0);
+ ab.cell(11,1, 1, 1);
+ ab.cell(12,1, 1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])"));
+ bb.cell(13,0, 0, 0);
+ bb.cell(14,0, 0, 1);
+ bb.cell(15,0, 1, 0);
+ bb.cell(16,0, 1, 1);
+ bb.cell(17,0, 2, 0);
+ bb.cell(18,0, 2, 1);
+ bb.cell(19,1, 0, 0);
+ bb.cell(20,1, 0, 1);
+ bb.cell(21,1, 1, 0);
+ bb.cell(22,1, 1, 1);
+ bb.cell(23,1, 2, 0);
+ bb.cell(24,1, 2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])"));
+ rb.cell( 94,0, 0, 0);
+ rb.cell(100,0, 0, 1);
+ rb.cell(229,0, 1, 0);
+ rb.cell(244,0, 1, 1);
+ rb.cell(508,1, 0, 0);
+ rb.cell(532,1, 0, 1);
+ rb.cell(697,1, 1, 0);
+ rb.cell(730,1, 1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2")
+ .rename("d3","d2");
+ assertEquals(r, result);
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index 8a58cb0bbed..55069eaced7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals;
/**
* Tests translation of composite to primitive tensor function translation.
- *
+ *
* @author bratseth
*/
public class TensorFunctionTestCase {
@@ -16,12 +16,12 @@ public class TensorFunctionTestCase {
public void testTranslation() {
assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))",
new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
- assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
+ assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}
-
+
private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
assertEquals(expectedTranslation, inputFunction.toPrimitive().toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 349309a5052..15a872e439f 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase {
assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
}
-
+
@Test
public void testSerializationToSeparateType() {
assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])"));
@@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index b1d7d797b3e..33dfca017f4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
index 68bf59e3ed9..f002637847b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
@@ -50,7 +50,7 @@ public class SerializationTestCase {
JsonNode node = mapper.readTree(test);
if (node.has("tensor") && node.has("binary")) {
System.out.println("Running test: " + test);
-
+
Tensor tensor = buildTensor(node.get("tensor"));
String spec = getSpec(node.get("tensor"));
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
@@ -123,7 +123,7 @@ public class SerializationTestCase {
private byte[] getBytes(String binaryRepresentation) {
return parseHexValue(binaryRepresentation.substring(2));
}
-
+
private byte[] parseHexValue(String s) {
final int len = s.length();
byte[] bytes = new byte[len/2];
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index d17148cf8dc..f895b64379b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}
diff --git a/vespalib/src/vespa/vespalib/stllike/hash_map.h b/vespalib/src/vespa/vespalib/stllike/hash_map.h
index 31185a9ff7c..023594d3018 100644
--- a/vespalib/src/vespa/vespalib/stllike/hash_map.h
+++ b/vespalib/src/vespa/vespalib/stllike/hash_map.h
@@ -35,7 +35,7 @@ public:
size_t capacity() const { return _ht.capacity(); }
size_t size() const { return _ht.size(); }
bool empty() const { return _ht.empty(); }
- insert_result insert(const value_type & value);
+ insert_result insert(const value_type & value) { return _ht.insert(value); }
template <typename InputIt>
void insert(InputIt first, InputIt last);
const V & operator [] (const K & key) const { return _ht.find(key)->second; }
diff --git a/vespalib/src/vespa/vespalib/stllike/hash_map.hpp b/vespalib/src/vespa/vespalib/stllike/hash_map.hpp
index 359ba235a36..b526188b8b2 100644
--- a/vespalib/src/vespa/vespalib/stllike/hash_map.hpp
+++ b/vespalib/src/vespa/vespalib/stllike/hash_map.hpp
@@ -17,13 +17,7 @@ hash_map<K, V, H, EQ, M>::hash_map(size_t reserveSize, H hasher, EQ equality) :
{ }
template <typename K, typename V, typename H, typename EQ, typename M>
-hash_map<K, V, H, EQ, M>::~hash_map() { }
-
-template <typename K, typename V, typename H, typename EQ, typename M>
-typename hash_map<K, V, H, EQ, M>::insert_result
-hash_map<K, V, H, EQ, M>::insert(const value_type & value) {
- return _ht.insert(value);
-}
+hash_map<K, V, H, EQ, M>::~hash_map() = default;
template <typename K, typename V, typename H, typename EQ, typename M>
void
@@ -64,12 +58,20 @@ hash_map<K, V, H, EQ, M>::getMemoryUsed() const
}
-#define VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, H) \
- template class vespalib::hash_map<K, V, H>; \
- template class vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>; \
- template vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>::insert_result \
- vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>::insert(std::pair<K,V> &&); \
+
+#define VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(K, V, H, E, M) \
+ template class vespalib::hash_map<K, V, H, E, M>; \
+ template class vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>; \
+ template vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert_result \
+ vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert(std::pair<K,V> &&); \
+ template vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert_result \
+ vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insertInternal(std::pair<K,V> &&); \
template class vespalib::Array<vespalib::hash_node<std::pair<K,V>>>;
+#define VESPALIB_HASH_MAP_INSTANTIATE_H_E(K, V, H, E) \
+ VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(K, V, H, E, vespalib::hashtable_base::prime_modulator)
+
+#define VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, H) VESPALIB_HASH_MAP_INSTANTIATE_H_E(K, V, H, std::equal_to<K>)
+
#define VESPALIB_HASH_MAP_INSTANTIATE(K, V) VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, vespalib::hash<K>)
diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.h b/vespalib/src/vespa/vespalib/stllike/hashtable.h
index 263ee952c2e..15949067a60 100644
--- a/vespalib/src/vespa/vespalib/stllike/hashtable.h
+++ b/vespalib/src/vespa/vespalib/stllike/hashtable.h
@@ -141,19 +141,18 @@ public:
typedef Value* pointer;
typedef std::forward_iterator_tag iterator_category;
- iterator(hashtable * hash, next_t start) : _hash(start), _subNode(start), _hashTable(hash) {
- advanceToNextValidHash();
- }
- iterator(hashtable * hash, next_t start, next_t subNode) : _hash(start), _subNode(subNode), _hashTable(hash) { }
- Value & operator * () const { return _hashTable->get(_subNode); }
- Value * operator -> () const { return & _hashTable->get(_subNode); }
- iterator & operator ++ () {
- if (_hashTable->_nodes[_subNode].hasNext()) {
- _subNode = _hashTable->_nodes[_subNode].getNext();
- } else {
- _hash++;
+ iterator(hashtable * hash) : _current(0), _hashTable(hash) {
+ if ((_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid()) {
advanceToNextValidHash();
}
+ }
+ iterator(hashtable * hash, next_t pos) : _current(pos), _hashTable(hash) { }
+ static iterator end(hashtable *hash) { return iterator(hash, Node::npos); }
+
+ Value & operator * () const { return _hashTable->get(_current); }
+ Value * operator -> () const { return & _hashTable->get(_current); }
+ iterator & operator ++ () {
+ advanceToNextValidHash();
return *this;
}
iterator operator ++ (int) {
@@ -161,19 +160,19 @@ public:
++(*this);
return prev;
}
- bool operator==(const iterator& rhs) const { return (_subNode == rhs._subNode); }
- bool operator!=(const iterator& rhs) const { return (_subNode != rhs._subNode); }
+ bool operator==(const iterator& rhs) const { return (_current == rhs._current); }
+ bool operator!=(const iterator& rhs) const { return (_current != rhs._current); }
/// Carefull about this one. Only used by lrucache.
- next_t getInternalIndex() const { return _subNode; }
- void setInternalIndex(next_t n) { _subNode = n; }
- next_t getHash() const { return _hash; }
+ next_t getInternalIndex() const { return _current; }
+ void setInternalIndex(next_t n) { _current = n; }
private:
void advanceToNextValidHash() {
- for (;(_hash < _hashTable->getTableSize()) && ! _hashTable->_nodes[_hash].valid(); _hash++) { }
- _subNode = (_hash < _hashTable->getTableSize()) ? _hash : Node::npos;
+ for (_current++;(_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid(); _current++) { }
+ if (_current >= _hashTable->initializedSize()) {
+ _current = Node::npos;
+ }
}
- next_t _hash;
- next_t _subNode;
+ next_t _current;
hashtable * _hashTable;
friend class hashtable::const_iterator;
@@ -186,21 +185,19 @@ public:
typedef const Value* pointer;
typedef std::forward_iterator_tag iterator_category;
- const_iterator(const hashtable * hash, next_t start) : _hash(start), _subNode(start), _hashTable(hash) {
- advanceToNextValidHash();
- }
- const_iterator(const hashtable * hash, next_t start, next_t subNode) : _hash(start), _subNode(subNode), _hashTable(hash) { }
- const_iterator(const iterator &i)
- : _hash(i._hash), _subNode(i._subNode), _hashTable(i._hashTable) {}
- const Value & operator * () const { return _hashTable->get(_subNode); }
- const Value * operator -> () const { return & _hashTable->get(_subNode); }
- const_iterator & operator ++ () {
- if (_hashTable->_nodes[_subNode].hasNext()) {
- _subNode = _hashTable->_nodes[_subNode].getNext();
- } else {
- _hash++;
+ const_iterator(const hashtable * hash) : _current(0), _hashTable(hash) {
+ if ((_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid()) {
advanceToNextValidHash();
}
+ }
+ const_iterator(const hashtable * hash, next_t pos) : _current(pos), _hashTable(hash) { }
+ const_iterator(const iterator &i) : _current(i._current), _hashTable(i._hashTable) {}
+ static const_iterator end(const hashtable *hash) { return const_iterator(hash, Node::npos); }
+
+ const Value & operator * () const { return _hashTable->get(_current); }
+ const Value * operator -> () const { return & _hashTable->get(_current); }
+ const_iterator & operator ++ () {
+ advanceToNextValidHash();
return *this;
}
const_iterator operator ++ (int) {
@@ -208,17 +205,17 @@ public:
++(*this);
return prev;
}
- bool operator==(const const_iterator& rhs) const { return (_subNode == rhs._subNode); }
- bool operator!=(const const_iterator& rhs) const { return (_subNode != rhs._subNode); }
- next_t getInternalIndex() const { return _subNode; }
- next_t getHash() const { return _hash; }
+ bool operator==(const const_iterator& rhs) const { return (_current == rhs._current); }
+ bool operator!=(const const_iterator& rhs) const { return (_current != rhs._current); }
+ next_t getInternalIndex() const { return _current; }
private:
void advanceToNextValidHash() {
- for (;(_hash < _hashTable->getTableSize()) && ! _hashTable->_nodes[_hash].valid(); _hash++) { }
- _subNode = (_hash < _hashTable->getTableSize()) ? _hash : Node::npos;
+ for (_current++;(_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid(); _current++) { }
+ if (_current >= _hashTable->initializedSize()) {
+ _current = Node::npos;
+ }
}
- next_t _hash;
- next_t _subNode;
+ next_t _current;
const hashtable * _hashTable;
};
typedef std::pair<iterator, bool> insert_result;
@@ -231,10 +228,10 @@ public:
hashtable(size_t reservedSpace);
hashtable(size_t reservedSpace, const Hash & hasher, const Equal & equal);
virtual ~hashtable();
- iterator begin() { return iterator(this, 0); }
- iterator end() { return iterator(this, Node::npos); }
- const_iterator begin() const { return const_iterator(this, 0); }
- const_iterator end() const { return const_iterator(this, Node::npos); }
+ iterator begin() { return iterator(this); }
+ iterator end() { return iterator::end(this); }
+ const_iterator begin() const { return const_iterator(this); }
+ const_iterator end() const { return const_iterator::end(this); }
size_t capacity() const { return _nodes.capacity(); }
size_t size() const { return _count; }
bool empty() const { return _count == 0; }
@@ -249,7 +246,9 @@ public:
const_iterator find(const AltKey & key) const { return find<AltKey, AltExtract, AltHash, AltEqual>(key, AltExtract()); }
const_iterator find(const Key & key) const;
template <typename V>
- insert_result insert(V && node);
+ insert_result insert(V && node) {
+ return insertInternal(std::forward<V>(node));
+ }
void erase(const Key & key);
void reserve(size_t sz) {
if (sz > _nodes.capacity()) {
@@ -280,7 +279,8 @@ protected:
Value & getByInternalIndex(size_t index) { return _nodes[index].getValue(); }
const Value & getByInternalIndex(size_t index) const { return _nodes[index].getValue(); }
template <typename MoveHandler>
- void erase(MoveHandler & moveHandler, const const_iterator & key);
+ void erase(MoveHandler & moveHandler, next_t h, const const_iterator & key);
+ next_t hash(const Key & key) const { return modulator(_hasher(key)); }
private:
Modulator _modulator;
size_t _count;
@@ -292,7 +292,7 @@ private:
const Value & get(size_t index) const { return _nodes[index].getValue(); }
next_t modulator(next_t key) const { return _modulator.modulo(key); }
next_t getTableSize() const { return _modulator.getTableSize(); }
- next_t hash(const Key & key) const { return modulator(_hasher(key)); }
+ size_t initializedSize() const { return _nodes.size(); }
template <typename MoveHandler>
void move(MoveHandler & moveHandler, next_t from, next_t to) {
_nodes[to] = std::move(_nodes[from]);
diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp
index 359e71aa0d2..f499ba35f3f 100644
--- a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp
+++ b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp
@@ -67,11 +67,10 @@ typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::iterator
hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const Key & key)
{
next_t h = hash(key);
- if (_nodes[h].valid()) {
- next_t start(h);
+ if (__builtin_expect(_nodes[h].valid(), true)) {
do {
- if (_equal(_keyExtractor(_nodes[h].getValue()), key)) {
- return iterator(this, start, h);
+ if (__builtin_expect(_equal(_keyExtractor(_nodes[h].getValue()), key), true)) {
+ return iterator(this, h);
}
h = _nodes[h].getNext();
} while (h != Node::npos);
@@ -84,11 +83,10 @@ typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::const_iterat
hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const Key & key) const
{
next_t h = hash(key);
- if (_nodes[h].valid()) {
- next_t start(h);
+ if (__builtin_expect(_nodes[h].valid(), true)) {
do {
- if (_equal(_keyExtractor(_nodes[h].getValue()), key)) {
- return const_iterator(this, start, h);
+ if (__builtin_expect(_equal(_keyExtractor(_nodes[h].getValue()), key), true)) {
+ return const_iterator(this, h);
}
h = _nodes[h].getNext();
} while (h != Node::npos);
@@ -104,11 +102,10 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k
AltHash altHasher;
next_t h = modulator(altHasher(key));
if (_nodes[h].valid()) {
- next_t start(h);
AltEqual altEqual;
do {
if (altEqual(altExtract(_keyExtractor(_nodes[h].getValue())), key)) {
- return const_iterator(this, start, h);
+ return const_iterator(this, h);
}
h = _nodes[h].getNext();
} while (h != Node::npos);
@@ -124,11 +121,10 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k
AltHash altHasher;
next_t h = modulator(altHasher(key));
if (_nodes[h].valid()) {
- next_t start(h);
AltEqual altEqual;
do {
if (altEqual(altExtract(_keyExtractor(_nodes[h].getValue())), key)) {
- return iterator(this, start, h);
+ return iterator(this, h);
}
h = _nodes[h].getNext();
} while (h != Node::npos);
@@ -137,19 +133,12 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k
}
template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator >
-template<typename V>
-typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insert_result
-hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insert(V && node) {
- return insertInternal(std::forward<V>(node));
-}
-
-template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator >
void
hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(const Key & key) {
const_iterator found(find(key));
if (found != end()) {
DefaultMoveHandler moveHandler;
- erase(moveHandler, found);
+ erase(moveHandler, hash(key), found);
}
}
@@ -169,11 +158,11 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insertInternal(V && n
if ( ! _nodes[h].valid() ) {
_nodes[h] = std::forward<V>(node);
_count++;
- return insert_result(iterator(this, h, h), true);
+ return insert_result(iterator(this, h), true);
} else if (_nodes.size() <= _nodes.capacity()) {
for (next_t c(h); c != Node::npos; c = _nodes[c].getNext()) {
if (_equal(_keyExtractor(_nodes[c].getValue()), _keyExtractor(node))) {
- return insert_result(iterator(this, h, c), false);
+ return insert_result(iterator(this, c), false);
}
}
if (_nodes.size() < _nodes.capacity()) {
@@ -182,7 +171,7 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insertInternal(V && n
_nodes[h].setNext(newIdx);
new (_nodes.push_back_fast()) Node(std::forward<V>(node), p);
_count++;
- return insert_result(iterator(this, h, newIdx), true);
+ return insert_result(iterator(this, newIdx), true);
} else {
resize(_nodes.capacity()*2);
return insertInternal(std::forward<V>(node));
@@ -214,9 +203,8 @@ void hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::reclaim(MoveHand
template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator >
template <typename MoveHandler>
void
-hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(MoveHandler & moveHandler, const const_iterator & it)
+hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(MoveHandler & moveHandler, next_t h, const const_iterator & it)
{
- next_t h = it.getHash();
next_t prev = Node::npos;
do {
if (h == it.getInternalIndex()) {
diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
index 15257e11cbe..4c932969460 100644
--- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
+++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.curator;
import com.google.inject.Inject;
import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.net.HostName;
import com.yahoo.path.Path;
import com.yahoo.vespa.curator.recipes.CuratorCounter;
import com.yahoo.vespa.zookeeper.ZooKeeperServer;
@@ -21,7 +22,6 @@ import org.apache.curator.framework.state.ConnectionState;
import org.apache.curator.framework.state.ConnectionStateListener;
import org.apache.curator.retry.ExponentialBackoffRetry;
-import java.io.Closeable;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
@@ -68,16 +68,26 @@ public class Curator implements AutoCloseable {
this(createConnectionSpec(configserverConfig));
}
- private static String createConnectionSpec(ConfigserverConfig config) {
+ static String createConnectionSpec(ConfigserverConfig config) {
+ String thisServer = HostName.getLocalhost();
+
StringBuilder sb = new StringBuilder();
for (int i = 0; i < config.zookeeperserver().size(); i++) {
ConfigserverConfig.Zookeeperserver server = config.zookeeperserver(i);
- sb.append(server.hostname());
- sb.append(":");
- sb.append(server.port());
- if (i < config.zookeeperserver().size() - 1) {
- sb.append(",");
+
+ String spec = String.format("%s:%d", server.hostname(), server.port());
+
+ if (config.zookeeperLocalhostAffinity() && server.hostname().equals(thisServer)) {
+ // Only connect to localhost server if possible, to save network traffic
+ // and balance load.
+ return spec;
}
+
+ if (sb.length() > 0) {
+ sb.append(',');
+ }
+
+ sb.append(spec);
}
return sb.toString();
}
diff --git a/zkfacade/src/test/java/com/yahoo/vespa/zookeeper/CuratorTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java
index 36205bdaca3..1899dcfe7cd 100644
--- a/zkfacade/src/test/java/com/yahoo/vespa/zookeeper/CuratorTest.java
+++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java
@@ -1,8 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.zookeeper;
+package com.yahoo.vespa.curator;
import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.vespa.curator.Curator;
+import com.yahoo.net.HostName;
import org.apache.curator.test.TestingServer;
import org.junit.After;
import org.junit.Before;
@@ -11,7 +11,6 @@ import org.junit.Test;
import java.io.IOException;
import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
/**
@@ -74,6 +73,23 @@ public class CuratorTest {
}
}
+ @Test
+ public void localhost_affinity() {
+ String localhostHostName = "myhost";
+ int localhostPort = 123;
+ String localhostSpec = localhostHostName + ":" + localhostPort;
+
+ ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder();
+ builder.zookeeperserver(createZKBuilder(localhostHostName, localhostPort));
+ builder.zookeeperserver(createZKBuilder("otherhost", 345));
+ builder.zookeeperLocalhostAffinity(true);
+ ConfigserverConfig config = new ConfigserverConfig(builder);
+
+ HostName.setHostNameForTestingOnly(localhostHostName);
+
+ assertThat(Curator.createConnectionSpec(config), is(localhostSpec));
+ }
+
private ConfigserverConfig createTestConfig() {
ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder();
builder.zookeeperserver(createZKBuilder("localhost", port1));