summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java2
-rw-r--r--configdefinitions/src/vespa/dispatch.def3
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java2
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java3
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java74
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java3
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java12
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java8
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java1
-rw-r--r--container-search/src/main/java/com/yahoo/fs4/MapEncoder.java85
-rw-r--r--container-search/src/main/java/com/yahoo/search/Query.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java15
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java50
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java7
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java6
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java15
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java71
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java10
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java2
-rw-r--r--container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java5
-rw-r--r--container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java6
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java15
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java113
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java35
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java4
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java47
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java116
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java40
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java33
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java5
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Action.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Action.java)2
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java29
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Context.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Context.java)2
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/PathGroup.java)10
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Policy.java)13
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Privilege.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Privilege.java)2
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java48
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Role.java)54
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java104
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java51
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java25
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java21
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java5
-rw-r--r--controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java74
-rw-r--r--controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/PathGroupTest.java (renamed from controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/PathGroupTest.java)3
-rw-r--r--controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java54
-rw-r--r--controller-server/pom.xml7
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java66
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java19
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java7
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java17
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java115
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java118
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java38
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java122
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java61
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java3
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java2
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java75
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java17
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java39
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json13
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json7
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json5
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json8
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java122
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java120
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java56
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java86
-rw-r--r--document/src/main/java/com/yahoo/document/CollectionDataType.java7
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java2
-rw-r--r--documentgen-test/etc/complex/music4.sd3
-rw-r--r--documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java69
-rw-r--r--eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp1
-rw-r--r--eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp200
-rw-r--r--eval/src/vespa/eval/eval/value_type.h6
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.cpp8
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.h8
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor.h6
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.h6
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor.cpp2
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor.h18
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp2
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp18
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h7
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h3
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h34
-rw-r--r--eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h2
-rw-r--r--eval/src/vespa/eval/tensor/join_tensors.h6
-rw-r--r--eval/src/vespa/eval/tensor/serialization/common.h9
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp110
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.h21
-rw-r--r--eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp13
-rw-r--r--eval/src/vespa/eval/tensor/serialization/slime_binary_format.h11
-rw-r--r--eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp16
-rw-r--r--eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h9
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp80
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.h26
-rw-r--r--eval/src/vespa/eval/tensor/tensor.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h8
-rw-r--r--eval/src/vespa/eval/tensor/tensor_address.cpp10
-rw-r--r--eval/src/vespa/eval/tensor/tensor_address_element_iterator.h3
-rw-r--r--eval/src/vespa/eval/tensor/tensor_builder.h6
-rw-r--r--eval/src/vespa/eval/tensor/tensor_factory.cpp24
-rw-r--r--eval/src/vespa/eval/tensor/tensor_factory.h9
-rw-r--r--eval/src/vespa/eval/tensor/tensor_mapper.cpp4
-rw-r--r--eval/src/vespa/eval/tensor/tensor_mapper.h6
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java24
-rw-r--r--jdisc_core/src/main/java/com/yahoo/jdisc/Request.java7
-rw-r--r--jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java9
-rw-r--r--jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java9
-rw-r--r--jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java77
-rw-r--r--jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java31
-rw-r--r--jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java13
-rw-r--r--jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java88
-rw-r--r--jrt/pom.xml10
-rw-r--r--logd/src/logd/legacy_forwarder.cpp7
-rw-r--r--logd/src/logd/watcher.cpp2
-rw-r--r--logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp4
-rw-r--r--logd/src/tests/watcher/watcher_test.cpp3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java40
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java33
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java21
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java22
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java41
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java97
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java5
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java4
-rw-r--r--searchlib/abi-spec.json1
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java2
-rw-r--r--searchlib/src/tests/aggregator/perdocexpr.cpp69
-rw-r--r--searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp68
-rw-r--r--searchlib/src/tests/grouping/grouping_test.cpp33
-rw-r--r--searchlib/src/tests/groupingengine/groupingengine_test.cpp26
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/aggregation.cpp237
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/aggregationresult.h4
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h2
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h12
-rw-r--r--searchlib/src/vespa/searchlib/expression/numericresultnode.h7
-rw-r--r--searchlib/src/vespa/searchlib/expression/singleresultnode.h5
-rw-r--r--searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.cpp87
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.h8
-rw-r--r--searchlib/src/vespa/searchlib/features/queryfeature.cpp27
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/fake_result.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/fake_result.h1
-rw-r--r--security-utils/pom.xml10
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java109
-rw-r--r--security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java39
-rw-r--r--vespa-documentgen-plugin/etc/complex/music3.sd3
-rw-r--r--vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java30
-rw-r--r--vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java12
-rw-r--r--vespajlib/abi-spec.json24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java80
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java72
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java8
-rw-r--r--vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java38
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java20
-rw-r--r--vespalog/src/vespa/log/log.cpp2
-rw-r--r--vespalog/src/vespa/log/log_message.cpp25
204 files changed, 3186 insertions, 1916 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index a0f35dbefe6..6109e5c4aae 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -191,6 +191,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
else { // default
dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString();
}
+
+ // TODO: Determine the type of the weighted set/vector and use that as value type
return Optional.of(new TensorType.Builder().mapped(dimension).build());
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index f197e2dfe6d..e12cc60b041 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -453,10 +453,9 @@ public class ConvertedModel {
*/
// TODO: determine when this is not necessary!
private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ if (after.equals(before)) return node;
+
+ TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
for (TensorType.Dimension dimension : before.dimensions()) {
if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
typeBuilder.indexed(dimension.name(), 1);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
index 5c96635fd8f..80440ac8eb4 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
@@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'");
+ exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])");
RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" constants {\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
index 2fcf5809ea5..f53ca15635f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
@@ -39,7 +39,7 @@ public class TensorFieldTestCase {
@Test
public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'");
+ exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])");
SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }"));
}
diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def
index 7d5979bcdf1..477a781ebbc 100644
--- a/configdefinitions/src/vespa/dispatch.def
+++ b/configdefinitions/src/vespa/dispatch.def
@@ -40,6 +40,9 @@ minWaitAfterCoverageFactor double default=0
# Maximum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage
maxWaitAfterCoverageFactor double default=1
+# Number of JRT connection supervisors
+numJrtSupervisors int default=8
+
# The unique key of a search node
node[].key int
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java
index 152fc47d807..3e9783cabe9 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java
@@ -2,8 +2,8 @@
package com.yahoo.vespa.config.server;
import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.config.model.application.provider.Bundle;
import com.yahoo.config.application.ConfigDefinitionDir;
+import com.yahoo.config.model.application.provider.Bundle;
import com.yahoo.io.IOUtils;
import com.yahoo.log.LogLevel;
import com.yahoo.vespa.defaults.Defaults;
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java
index 4fbda42fdc7..877b2acb86f 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java
@@ -32,7 +32,6 @@ import java.util.Set;
import java.util.stream.Collectors;
import static com.yahoo.config.model.api.container.ContainerServiceType.CONTAINER;
-import static com.yahoo.config.model.api.container.ContainerServiceType.LOGSERVER_CONTAINER;
import static com.yahoo.config.model.api.container.ContainerServiceType.QRSERVER;
/**
@@ -44,12 +43,12 @@ import static com.yahoo.config.model.api.container.ContainerServiceType.QRSERVER
public class ConfigConvergenceChecker extends AbstractComponent {
private static final ApplicationId routingApplicationId = ApplicationId.from("hosted-vespa", "routing", "default");
+ private static final String nodeAdminName = "node-admin";
private static final String statePath = "/state/v1/";
private static final String configSubPath = "config";
private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList(
CONTAINER.serviceName,
QRSERVER.serviceName,
- LOGSERVER_CONTAINER.serviceName,
"searchnode",
"storagenode",
"distributor"
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java
index 7731e13eac2..3705a0ec145 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.application;
-import com.google.common.collect.ImmutableSet;
import com.yahoo.concurrent.ThreadFactoryFactory;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.TenantName;
@@ -17,12 +16,13 @@ import com.yahoo.vespa.curator.transaction.CuratorTransaction;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent;
-import java.util.ArrayList;
import java.util.List;
-import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
/**
* The applications of a tenant, backed by ZooKeeper.
@@ -70,40 +70,39 @@ public class TenantApplications {
* @return a list of {@link ApplicationId}s that are active.
*/
public List<ApplicationId> listApplications() {
- try {
- List<String> appNodes = curator.framework().getChildren().forPath(applicationsPath.getAbsolute());
- List<ApplicationId> applicationIds = new ArrayList<>();
- for (String appNode : appNodes) {
- parseApplication(appNode).ifPresent(applicationIds::add);
- }
- return applicationIds;
- } catch (Exception e) {
- throw new RuntimeException(TenantRepository.logPre(tenant)+"Unable to list applications", e);
- }
+ return curator.getChildren(applicationsPath).stream()
+ .flatMap(this::parseApplication)
+ .collect(Collectors.toUnmodifiableList());
}
- private Optional<ApplicationId> parseApplication(String appNode) {
+ // TODO jvenstad: Remove after it has run once everywhere.
+ private Stream<ApplicationId> parseApplication(String appNode) {
try {
- ApplicationId id = ApplicationId.fromSerializedForm(appNode);
- getSessionIdForApplication(id);
- return Optional.of(id);
- } catch (IllegalArgumentException e) {
- log.log(LogLevel.INFO, TenantRepository.logPre(tenant)+"Unable to parse application with id '" + appNode + "', ignoring.");
- return Optional.empty();
+ return Stream.of(ApplicationId.fromSerializedForm(appNode));
+ } catch (IllegalArgumentException __) {
+ log.log(LogLevel.INFO, TenantRepository.logPre(tenant) + "Unable to parse application id from '" +
+ appNode + "'; deleting it as it shouldn't be here.");
+ try {
+ curator.delete(applicationsPath.append(appNode));
+ }
+ catch (Exception e) {
+ log.log(LogLevel.WARNING, TenantRepository.logPre(tenant) + "Failed to clean up stray node '" + appNode + "'!", e);
+ }
+ return Stream.empty();
}
}
/**
- * Register active application and adds it to the repo. If it already exists it is overwritten.
+ * Returns a transaction which writes the given session id as the currently active for the given application.
*
* @param applicationId An {@link ApplicationId} that represents an active application.
* @param sessionId Id of the session containing the application package for this id.
*/
public Transaction createPutApplicationTransaction(ApplicationId applicationId, long sessionId) {
if (listApplications().contains(applicationId)) {
- return new CuratorTransaction(curator).add(CuratorOperations.setData(applicationsPath.append(applicationId.serializedForm()).getAbsolute(), Utf8.toAsciiBytes(sessionId)));
+ return new CuratorTransaction(curator).add(CuratorOperations.setData(applicationPath(applicationId).getAbsolute(), Utf8.toAsciiBytes(sessionId)));
} else {
- return new CuratorTransaction(curator).add(CuratorOperations.create(applicationsPath.append(applicationId.serializedForm()).getAbsolute(), Utf8.toAsciiBytes(sessionId)));
+ return new CuratorTransaction(curator).add(CuratorOperations.create(applicationPath(applicationId).getAbsolute(), Utf8.toAsciiBytes(sessionId)));
}
}
@@ -115,7 +114,7 @@ public class TenantApplications {
* @throws IllegalArgumentException if the application does not exist
*/
public long getSessionIdForApplication(ApplicationId applicationId) {
- String path = applicationsPath.append(applicationId.serializedForm()).getAbsolute();
+ String path = applicationPath(applicationId).getAbsolute();
try {
return Long.parseLong(Utf8.toString(curator.framework().getData().forPath(path)));
} catch (Exception e) {
@@ -124,18 +123,22 @@ public class TenantApplications {
}
/**
- * Returns a transaction which deletes this application
- *
- * @param applicationId an {@link ApplicationId} to delete.
+ * Returns a transaction which deletes this application.
*/
public CuratorTransaction deleteApplication(ApplicationId applicationId) {
- Path path = applicationsPath.append(applicationId.serializedForm());
- return CuratorTransaction.from(CuratorOperations.delete(path.getAbsolute()), curator);
+ return CuratorTransaction.from(CuratorOperations.delete(applicationPath(applicationId).getAbsolute()), curator);
}
/**
- * Closes the application repo. Once a repo has been closed, it should not be used again.
- */
+ * Removes all applications not known to this from the config server state.
+ */
+ public void removeUnusedApplications() {
+ reloadHandler.removeApplicationsExcept(Set.copyOf(listApplications()));
+ }
+
+ /**
+ * Closes the application repo. Once a repo has been closed, it should not be used again.
+ */
public void close() {
directoryCache.close();
}
@@ -169,13 +172,8 @@ public class TenantApplications {
log.log(LogLevel.DEBUG, TenantRepository.logPre(applicationId) + "Application added: " + applicationId);
}
- /**
- * Removes unused applications
- *
- */
- public void removeUnusedApplications() {
- ImmutableSet<ApplicationId> activeApplications = ImmutableSet.copyOf(listApplications());
- reloadHandler.removeApplicationsExcept(activeApplications);
+ private Path applicationPath(ApplicationId id) {
+ return applicationsPath.append(id.serializedForm());
}
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
index 21716730825..082be2583c2 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
@@ -89,9 +89,8 @@ public class Deployment implements com.yahoo.config.provision.Deployment {
timeout, clock, true, true, session.getVespaVersion(), isBootstrap);
}
- public Deployment setIgnoreSessionStaleFailure(boolean ignoreSessionStaleFailure) {
+ public void setIgnoreSessionStaleFailure(boolean ignoreSessionStaleFailure) {
this.ignoreSessionStaleFailure = ignoreSessionStaleFailure;
- return this;
}
/** Prepares this. This does nothing if this is already prepared */
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java
index 0f9f8b72de1..0cdf5ebfe95 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java
@@ -83,12 +83,6 @@ public class LocalSession extends Session implements Comparable<LocalSession> {
setStatus(Session.Status.PREPARE);
}
- private Transaction setActive() {
- Transaction transaction = createSetStatusTransaction(Status.ACTIVATE);
- transaction.add(applicationRepo.createPutApplicationTransaction(zooKeeperClient.readApplicationId(), getSessionId()).operations());
- return transaction;
- }
-
private Transaction createSetStatusTransaction(Status status) {
return zooKeeperClient.createWriteStatusTransaction(status);
}
@@ -99,8 +93,10 @@ public class LocalSession extends Session implements Comparable<LocalSession> {
public Transaction createActivateTransaction() {
zooKeeperClient.createActiveWaiter();
- superModelGenerationCounter.increment();
- return setActive();
+ superModelGenerationCounter.increment(); // TODO jvenstad: I hope this counter isn't used for serious things, as it's updated way ahead of activation.
+ Transaction transaction = createSetStatusTransaction(Status.ACTIVATE);
+ transaction.add(applicationRepo.createPutApplicationTransaction(zooKeeperClient.readApplicationId(), getSessionId()).operations());
+ return transaction;
}
public Transaction createDeactivateTransaction() {
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java
index a3dea83d50c..5527d3060f7 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java
@@ -17,8 +17,6 @@ public interface SessionFactory {
/**
* Creates a new deployment session from an application package.
*
- *
- *
* @param applicationDirectory a File pointing to an application.
* @param applicationId application id for this new session.
* @param timeoutBudget Timeout for creating session and waiting for other servers.
@@ -29,10 +27,10 @@ public interface SessionFactory {
/**
* Creates a new deployment session from an already existing session.
*
- * @param existingSession The session to use as base
+ * @param existingSession the session to use as base
* @param logger a deploy logger where the deploy log will be written.
- * @param internalRedeploy if this session is for a system internal redeploy not an application package change
- * @param timeoutBudget Timeout for creating session and waiting for other servers.
+ * @param internalRedeploy whether this session is for a system internal redeploy — not an application package change
+ * @param timeoutBudget timeout for creating session and waiting for other servers.
* @return a new session
*/
LocalSession createSessionFromExisting(LocalSession existingSession, DeployLogger logger,
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java
index b79ea720aea..90eeb89dc8e 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java
@@ -194,4 +194,5 @@ public class SessionFactoryImpl implements SessionFactory, LocalSessionLoader {
}
return nonExistingActiveSession;
}
+
}
diff --git a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java
index 4245f51ace8..565a4c483c3 100644
--- a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java
+++ b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java
@@ -20,8 +20,13 @@ public class MapEncoder {
// TODO: Time to refactor
- private static final String TYPE_SUFFIX = ".type";
- private static final String TENSOR_TYPE = "tensor";
+ private static byte [] getUtf8(Object value) {
+ if (value instanceof Tensor) {
+ return TypedBinaryFormat.encode((Tensor)value);
+ } else {
+ return Utf8.toBytes(value.toString());
+ }
+ }
/**
* Encodes a single value as a complete binary map.
@@ -39,7 +44,7 @@ public class MapEncoder {
utf8 = Utf8.toBytes(key);
buffer.putInt(utf8.length);
buffer.put(utf8);
- utf8 = Utf8.toBytes(value.toString());
+ utf8 = getUtf8(value);
buffer.putInt(utf8.length);
buffer.put(utf8);
@@ -64,7 +69,12 @@ public class MapEncoder {
utf8 = Utf8.toBytes(key);
buffer.putInt(utf8.length);
buffer.put(utf8);
- utf8 = Utf8.toBytes(property.getValue() != null ? property.getValue().toString() : "");
+ Object value = property.getValue();
+ if (value == null) {
+ utf8 = Utf8.toBytes("");
+ } else {
+ utf8 = getUtf8(value);
+ }
buffer.putInt(utf8.length);
buffer.put(utf8);
}
@@ -78,53 +88,21 @@ public class MapEncoder {
*
* Returns the number of maps encoded - 0 or 1
*/
- public static int encodeStringMultiMap(String mapName, Map<String,List<String>> map, ByteBuffer buffer) {
- if (map.isEmpty()) return 0;
-
- byte [] utf8 = Utf8.toBytes(mapName);
- buffer.putInt(utf8.length);
- buffer.put(utf8);
- buffer.putInt(countStringEntries(map));
- for (Map.Entry<String, List<String>> property : map.entrySet()) {
- String key = property.getKey();
- for (Object value : property.getValue()) {
- utf8 = Utf8.toBytes(key);
- buffer.putInt(utf8.length);
- buffer.put(utf8);
- utf8 = Utf8.toBytes(value.toString());
- buffer.putInt(utf8.length);
- buffer.put(utf8);
- }
- }
-
- return 1;
- }
- /**
- * Encodes a multi-map as binary.
- * Does nothing if the value is null.
- *
- * Returns the number of maps encoded - 0 or 1
- */
- public static int encodeObjectMultiMap(String mapName, Map<String,List<Object>> map, ByteBuffer buffer) {
+ public static <T> int encodeMultiMap(String mapName, Map<String,List<T>> map, ByteBuffer buffer) {
if (map.isEmpty()) return 0;
byte[] utf8 = Utf8.toBytes(mapName);
buffer.putInt(utf8.length);
buffer.put(utf8);
- addTensorTypeInfo(map);
- buffer.putInt(countObjectEntries(map));
- for (Map.Entry<String, List<Object>> property : map.entrySet()) {
+ buffer.putInt(countEntries(map));
+ for (Map.Entry<String, List<T>> property : map.entrySet()) {
String key = property.getKey();
for (Object value : property.getValue()) {
utf8 = Utf8.toBytes(key);
buffer.putInt(utf8.length);
buffer.put(utf8);
- if (value instanceof Tensor) {
- utf8 = TypedBinaryFormat.encode((Tensor)value);
- } else {
- utf8 = Utf8.toBytes(value.toString());
- }
+ utf8 = getUtf8(value);
buffer.putInt(utf8.length);
buffer.put(utf8);
}
@@ -133,32 +111,9 @@ public class MapEncoder {
return 1;
}
- private static void addTensorTypeInfo(Map<String, List<Object>> map) {
- Map<String, Tensor> tensorsToTag = new HashMap<>();
- for (Map.Entry<String, List<Object>> entry : map.entrySet()) {
- for (Object value : entry.getValue()) {
- if (value instanceof Tensor) {
- tensorsToTag.put(entry.getKey(), (Tensor)value);
- }
- }
- }
- for (Map.Entry<String, Tensor> entry : tensorsToTag.entrySet()) {
- // Ensure that we only have a single tensor associated with each key
- map.put(entry.getKey(), Arrays.asList(entry.getValue()));
- map.put(entry.getKey() + TYPE_SUFFIX, Arrays.asList(TENSOR_TYPE));
- }
- }
-
- private static int countStringEntries(Map<String, List<String>> value) {
- int entries = 0;
- for (Map.Entry<String, List<String>> property : value.entrySet())
- entries += property.getValue().size();
- return entries;
- }
-
- private static int countObjectEntries(Map<String, List<Object>> value) {
+ private static <T> int countEntries(Map<String, List<T>> value) {
int entries = 0;
- for (Map.Entry<String, List<Object>> property : value.entrySet())
+ for (Map.Entry<String, List<T>> property : value.entrySet())
entries += property.getValue().size();
return entries;
}
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index b97ee87f650..a5007c9cc33 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -1055,7 +1055,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
// TODO: Push down
if (presentation.getHighlight() != null) {
- mapCount += MapEncoder.encodeStringMultiMap(Highlight.HIGHLIGHTTERMS, presentation.getHighlight().getHighlightTerms(), buffer);
+ mapCount += MapEncoder.encodeMultiMap(Highlight.HIGHLIGHTTERMS, presentation.getHighlight().getHighlightTerms(), buffer);
}
// TODO: Push down
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java
index cc37df04a62..e54e2187818 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java
@@ -2,8 +2,6 @@
package com.yahoo.search.dispatch.rpc;
import com.yahoo.compress.CompressionType;
-import com.yahoo.compress.Compressor;
-import com.yahoo.prelude.Pong;
import com.yahoo.prelude.fastsearch.FastHit;
import java.util.List;
@@ -15,14 +13,6 @@ import java.util.Optional;
* @author bratseth
*/
interface Client {
-
- void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression,
- int uncompressedLength, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver,
- double timeoutSeconds);
-
- void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength,
- byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds);
-
/** Creates a connection to a particular node in this */
NodeConnection createConnection(String hostname, int port);
@@ -91,6 +81,11 @@ interface Client {
}
interface NodeConnection {
+ void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength, byte[] compressedSlime,
+ RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds);
+
+ void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
+ ResponseReceiver responseReceiver, double timeoutSeconds);
/** Closes this connection */
void close();
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java
index 2aa01b05955..7e48733106a 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java
@@ -29,31 +29,6 @@ class RpcClient implements Client {
return new RpcNodeConnection(hostname, port, supervisor);
}
- @Override
- public void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression, int uncompressedLength,
- byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) {
- Request request = new Request("proton.getDocsums");
- request.parameters().add(new Int8Value(compression.getCode()));
- request.parameters().add(new Int32Value(uncompressedLength));
- request.parameters().add(new DataValue(compressedSlime));
-
- request.setContext(hits);
- RpcNodeConnection rpcNode = ((RpcNodeConnection) node);
- rpcNode.invokeAsync(request, timeoutSeconds, new RpcDocsumResponseWaiter(rpcNode, responseReceiver));
- }
-
- @Override
- public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
- ResponseReceiver responseReceiver, double timeoutSeconds) {
- Request request = new Request(rpcMethod);
- request.parameters().add(new Int8Value(compression.getCode()));
- request.parameters().add(new Int32Value(uncompressedLength));
- request.parameters().add(new DataValue(compressedPayload));
-
- RpcNodeConnection rpcNode = ((RpcNodeConnection) node);
- rpcNode.invokeAsync(request, timeoutSeconds, new RpcProtobufResponseWaiter(rpcNode, responseReceiver));
- }
-
private static class RpcNodeConnection implements NodeConnection {
// Information about the connected node
@@ -73,7 +48,30 @@ class RpcClient implements Client {
description = "rpc node connection to " + hostname + ":" + port;
}
- public void invokeAsync(Request req, double timeout, RequestWaiter waiter) {
+ @Override
+ public void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength,
+ byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) {
+ Request request = new Request("proton.getDocsums");
+ request.parameters().add(new Int8Value(compression.getCode()));
+ request.parameters().add(new Int32Value(uncompressedLength));
+ request.parameters().add(new DataValue(compressedSlime));
+
+ request.setContext(hits);
+ invokeAsync(request, timeoutSeconds, new RpcDocsumResponseWaiter(this, responseReceiver));
+ }
+
+ @Override
+ public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
+ ResponseReceiver responseReceiver, double timeoutSeconds) {
+ Request request = new Request(rpcMethod);
+ request.parameters().add(new Int8Value(compression.getCode()));
+ request.parameters().add(new Int32Value(uncompressedLength));
+ request.parameters().add(new DataValue(compressedPayload));
+
+ invokeAsync(request, timeoutSeconds, new RpcProtobufResponseWaiter(this, responseReceiver));
+ }
+
+ private void invokeAsync(Request req, double timeout, RequestWaiter waiter) {
// TODO: Consider replacing this by a watcher on the target
synchronized(this) { // ensure we have exactly 1 valid connection across threads
if (target == null || ! target.isValid())
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java
index 760f7486923..aa72823c809 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java
@@ -100,7 +100,7 @@ public class RpcFillInvoker extends FillInvoker {
/** Send a getDocsums request to a node. Responses will be added to the given receiver. */
private void sendGetDocsumsRequest(int nodeId, List<FastHit> hits, String summaryClass, CompressionType compression,
Result result, GetDocsumsResponseReceiver responseReceiver) {
- Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId);
+ Client.NodeConnection node = resourcePool.getConnection(nodeId);
if (node == null) {
String error = "Could not fill hits from unknown node " + nodeId;
responseReceiver.receive(Client.ResponseOrError.fromError(error));
@@ -114,9 +114,8 @@ public class RpcFillInvoker extends FillInvoker {
byte[] serializedSlime = BinaryFormat
.encode(toSlime(rankProfile, summaryClass, query.getModel().getDocumentDb(), query.getSessionId(), hits));
double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0;
- Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, serializedSlime);
- resourcePool.client().getDocsums(hits, node, compressionResult.type(), serializedSlime.length, compressionResult.data(),
- responseReceiver, timeoutSeconds);
+ Compressor.Compression compressionResult = resourcePool.compress(query, serializedSlime);
+ node.getDocsums(hits, compressionResult.type(), serializedSlime.length, compressionResult.data(), responseReceiver, timeoutSeconds);
}
static private Slime toSlime(String rankProfile, String summaryClass, String docType, SessionId sessionId, List<FastHit> hits) {
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java
index f3479e2e4a9..c001b51ef11 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java
@@ -52,12 +52,12 @@ public class RpcPing implements Callable<Pong> {
}
private void sendPing(LinkedBlockingQueue<ResponseOrError<ProtobufResponse>> queue) {
- var connection = resourcePool.nodeConnections().get(node.key());
+ var connection = resourcePool.getConnection(node.key());
var ping = SearchProtocol.MonitorRequest.newBuilder().build().toByteArray();
double timeoutSeconds = ((double) clusterMonitor.getConfiguration().getRequestTimeout()) / 1000.0;
Compressor.Compression compressionResult = resourcePool.compressor().compress(PING_COMPRESSION, ping);
- resourcePool.client().request(RPC_METHOD, connection, compressionResult.type(), ping.length, compressionResult.data(),
- rsp -> queue.add(rsp), timeoutSeconds);
+ connection.request(RPC_METHOD, compressionResult.type(), ping.length, compressionResult.data(), rsp -> queue.add(rsp),
+ timeoutSeconds);
}
private Pong decodeReply(ProtobufResponse response) throws InvalidProtocolBufferException {
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java
index 3ec821beba8..cd4ba191a7d 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java
@@ -66,9 +66,6 @@ public class RpcProtobufFillInvoker extends FillInvoker {
protected void sendFillRequest(Result result, String summaryClass) {
ListMap<Integer, FastHit> hitsByNode = hitsByNode(result);
- CompressionType compression = CompressionType
- .valueOf(result.getQuery().properties().getString(RpcResourcePool.dispatchCompression, "LZ4").toUpperCase());
-
result.getQuery().trace(false, 5, "Sending ", hitsByNode.size(), " summary fetch requests with jrt/protobuf");
outstandingResponses = hitsByNode.size();
@@ -77,7 +74,7 @@ public class RpcProtobufFillInvoker extends FillInvoker {
var builder = ProtobufSerialization.createDocsumRequestBuilder(result.getQuery(), serverId, summaryClass, summaryNeedsQuery);
for (Map.Entry<Integer, List<FastHit>> nodeHits : hitsByNode.entrySet()) {
var payload = ProtobufSerialization.serializeDocsumRequest(builder, nodeHits.getValue());
- sendDocsumsRequest(nodeHits.getKey(), nodeHits.getValue(), payload, compression, result);
+ sendDocsumsRequest(nodeHits.getKey(), nodeHits.getValue(), payload, result);
}
}
@@ -117,8 +114,8 @@ public class RpcProtobufFillInvoker extends FillInvoker {
}
/** Send a docsums request to a node. Responses will be added to the given receiver. */
- private void sendDocsumsRequest(int nodeId, List<FastHit> hits, byte[] payload, CompressionType compression, Result result) {
- Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId);
+ private void sendDocsumsRequest(int nodeId, List<FastHit> hits, byte[] payload, Result result) {
+ Client.NodeConnection node = resourcePool.getConnection(nodeId);
if (node == null) {
String error = "Could not fill hits from unknown node " + nodeId;
receive(Client.ResponseOrError.fromError(error), hits);
@@ -129,9 +126,9 @@ public class RpcProtobufFillInvoker extends FillInvoker {
Query query = result.getQuery();
double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0;
- Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload);
- resourcePool.client().request(RPC_METHOD, node, compressionResult.type(), payload.length, compressionResult.data(),
- roe -> receive(roe, hits), timeoutSeconds);
+ Compressor.Compression compressionResult = resourcePool.compress(query, payload);
+ node.request(RPC_METHOD, compressionResult.type(), payload.length, compressionResult.data(), roe -> receive(roe, hits),
+ timeoutSeconds);
}
private void processResponses(Result result, String summaryClass) throws TimeoutException {
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java
index 830ba45ef0f..cccf8dd3693 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java
@@ -2,12 +2,20 @@
package com.yahoo.search.dispatch.rpc;
import com.google.common.collect.ImmutableMap;
+import com.yahoo.compress.CompressionType;
import com.yahoo.compress.Compressor;
+import com.yahoo.compress.Compressor.Compression;
import com.yahoo.processing.request.CompoundName;
+import com.yahoo.search.Query;
import com.yahoo.search.dispatch.FillInvoker;
+import com.yahoo.search.dispatch.rpc.Client.NodeConnection;
import com.yahoo.vespa.config.search.DispatchConfig;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.Map;
+import java.util.Random;
/**
* RpcResourcePool constructs {@link FillInvoker} objects that communicate with content nodes over RPC. It also contains
@@ -19,43 +27,70 @@ public class RpcResourcePool {
/** The compression method which will be used with rpc dispatch. "lz4" (default) and "none" is supported. */
public final static CompoundName dispatchCompression = new CompoundName("dispatch.compression");
- private final Compressor compressor = new Compressor();
- private final Client client;
+ private final Compressor compressor = new Compressor(CompressionType.LZ4, 5, 0.95, 32);
+ private final Random random = new Random();
/** Connections to the search nodes this talks to, indexed by node id ("partid") */
- private final ImmutableMap<Integer, Client.NodeConnection> nodeConnections;
+ private final ImmutableMap<Integer, NodeConnectionPool> nodeConnectionPools;
- public RpcResourcePool(Client client, Map<Integer, Client.NodeConnection> nodeConnections) {
- this.client = client;
- this.nodeConnections = ImmutableMap.copyOf(nodeConnections);
+ public RpcResourcePool(Map<Integer, Client.NodeConnection> nodeConnections) {
+ var builder = new ImmutableMap.Builder<Integer, NodeConnectionPool>();
+ nodeConnections.forEach((key, connection) -> builder.put(key, new NodeConnectionPool(Collections.singletonList(connection))));
+ this.nodeConnectionPools = builder.build();
}
public RpcResourcePool(DispatchConfig dispatchConfig) {
- this.client = new RpcClient();
+ var clients = new ArrayList<RpcClient>(dispatchConfig.numJrtSupervisors());
+ for (int i = 0; i < dispatchConfig.numJrtSupervisors(); i++) {
+ clients.add(new RpcClient());
+ }
- // Create node rpc connections, indexed by the node distribution key
- ImmutableMap.Builder<Integer, Client.NodeConnection> nodeConnectionsBuilder = new ImmutableMap.Builder<>();
- for (DispatchConfig.Node node : dispatchConfig.node()) {
- nodeConnectionsBuilder.put(node.key(), client.createConnection(node.host(), node.port()));
+ // Create node rpc connection pools, indexed by the node distribution key
+ var builder = new ImmutableMap.Builder<Integer, NodeConnectionPool>();
+ for (var node : dispatchConfig.node()) {
+ var connections = new ArrayList<Client.NodeConnection>(clients.size());
+ clients.forEach(client -> connections.add(client.createConnection(node.host(), node.port())));
+ builder.put(node.key(), new NodeConnectionPool(connections));
}
- this.nodeConnections = nodeConnectionsBuilder.build();
+ this.nodeConnectionPools = builder.build();
}
public Compressor compressor() {
return compressor;
}
- public Client client() {
- return client;
+ public Compression compress(Query query, byte[] payload) {
+ CompressionType compression = CompressionType.valueOf(query.properties().getString(dispatchCompression, "LZ4").toUpperCase());
+ return compressor.compress(compression, payload);
}
- public ImmutableMap<Integer, Client.NodeConnection> nodeConnections() {
- return nodeConnections;
+ public NodeConnection getConnection(int nodeId) {
+ var pool = nodeConnectionPools.get(nodeId);
+ if (pool == null) {
+ return null;
+ } else {
+ return pool.nextConnection();
+ }
}
public void release() {
- for (Client.NodeConnection nodeConnection : nodeConnections.values()) {
- nodeConnection.close();
+ nodeConnectionPools.values().forEach(NodeConnectionPool::release);
+ }
+
+ private class NodeConnectionPool {
+ private final List<Client.NodeConnection> connections;
+
+ NodeConnectionPool(List<Client.NodeConnection> connections) {
+ this.connections = connections;
+ }
+
+ Client.NodeConnection nextConnection() {
+ int slot = random.nextInt(connections.size());
+ return connections.get(slot);
+ }
+
+ void release() {
+ connections.forEach(Client.NodeConnection::close);
}
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java
index d70a7d95b63..75e9b06f445 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java
@@ -46,10 +46,7 @@ public class RpcSearchInvoker extends SearchInvoker implements Client.ResponseRe
protected void sendSearchRequest(Query query) throws IOException {
this.query = query;
- CompressionType compression = CompressionType
- .valueOf(query.properties().getString(RpcResourcePool.dispatchCompression, "LZ4").toUpperCase());
-
- Client.NodeConnection nodeConnection = resourcePool.nodeConnections().get(node.key());
+ Client.NodeConnection nodeConnection = resourcePool.getConnection(node.key());
if (nodeConnection == null) {
responses.add(Client.ResponseOrError.fromError("Could not send search to unknown node " + node.key()));
responseAvailable();
@@ -59,9 +56,8 @@ public class RpcSearchInvoker extends SearchInvoker implements Client.ResponseRe
var payload = ProtobufSerialization.serializeSearchRequest(query, searcher.getServerId());
double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0;
- Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload);
- resourcePool.client().request(RPC_METHOD, nodeConnection, compressionResult.type(), payload.length, compressionResult.data(), this,
- timeoutSeconds);
+ Compressor.Compression compressionResult = resourcePool.compress(query, payload);
+ nodeConnection.request(RPC_METHOD, compressionResult.type(), payload.length, compressionResult.data(), this, timeoutSeconds);
}
@Override
diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java
index 4158b0e7476..37a54a82c43 100644
--- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java
@@ -76,7 +76,7 @@ public class RankProperties implements Cloneable {
/** Encodes this in a binary internal representation and returns the number of property maps encoded (0 or 1) */
public int encode(ByteBuffer buffer, boolean encodeQueryData) {
if (encodeQueryData) {
- return MapEncoder.encodeObjectMultiMap("rank", properties, buffer);
+ return MapEncoder.encodeMultiMap("rank", properties, buffer);
}
else {
List<Object> sessionId = properties.get(GetDocSumsPacket.sessionIdKey);
diff --git a/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java b/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java
index 69ca646dbd5..e8c16e572ae 100644
--- a/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java
+++ b/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java
@@ -63,11 +63,10 @@ public class RankFeaturesTestCase {
assertEquals(entries.size(), properties.asMap().size());
Map<String, Object> decodedProperties = decode(type, encode(properties));
- assertEquals(entries.size() * 2, properties.asMap().size()); // tensor type info has been added
- assertEquals(entries.size() * 2, decodedProperties.size());
+ assertEquals(entries.size(), properties.asMap().size());
+ assertEquals(entries.size(), decodedProperties.size());
for (Entry entry : entries) {
assertEquals(entry.tensor, decodedProperties.get(entry.normalizedKey));
- assertEquals("tensor", decodedProperties.get(entry.normalizedKey + ".type"));
}
}
diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java
index f4be2943f5f..04b1d526c67 100644
--- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java
+++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java
@@ -145,7 +145,8 @@ public class FastSearcherTestCase {
doFill(fastSearcher, result);
ErrorMessage error = result.hits().getError();
assertEquals("Since we don't actually run summary backends we get this error when the Dispatcher is used",
- "Error response from rpc node connection to host1:0: Connection error", error.getDetailedMessage());
+ "Error response from rpc node connection to hostX:0: Connection error",
+ error.getDetailedMessage().replaceAll("host[12]", "hostX"));
}
{ // direct.summaries due to no summary features
@@ -154,7 +155,8 @@ public class FastSearcherTestCase {
doFill(fastSearcher, result);
ErrorMessage error = result.hits().getError();
assertEquals("Since we don't actually run summary backends we get this error when the Dispatcher is used",
- "Error response from rpc node connection to host1:0: Connection error", error.getDetailedMessage());
+ "Error response from rpc node connection to hostX:0: Connection error",
+ error.getDetailedMessage().replaceAll("host[12]", "hostX"));
}
}
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java
index e059008acac..6d1f19eeaf2 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java
@@ -22,7 +22,6 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
-
/**
* Tests using a dispatcher to fill a result
*
@@ -38,7 +37,7 @@ public class FillTestCase {
nodes.put(0, client.createConnection("host0", 123));
nodes.put(1, client.createConnection("host1", 123));
nodes.put(2, client.createConnection("host2", 123));
- RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes);
+ RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes);
RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true);
Query query = new Query();
@@ -75,7 +74,7 @@ public class FillTestCase {
nodes.put(0, client.createConnection("host0", 123));
nodes.put(1, client.createConnection("host1", 123));
nodes.put(2, client.createConnection("host2", 123));
- RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes);
+ RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes);
RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true);
Query query = new Query();
@@ -90,7 +89,7 @@ public class FillTestCase {
client.setDocsumReponse("host2", 1, "summaryClass1", map("field1", "s.2.1", "field2", 1));
client.setDocsumReponse("host1", 2, "summaryClass1", new HashMap<>());
client.setDocsumReponse("host2", 3, "summaryClass1", map("field1", "s.2.3", "field2", 3));
- client.setDocsumReponse("host0", 4, "summaryClass1",new HashMap<>());
+ client.setDocsumReponse("host0", 4, "summaryClass1", new HashMap<>());
factory.createFillInvoker(db()).fill(result, "summaryClass1");
@@ -115,7 +114,7 @@ public class FillTestCase {
Map<Integer, Client.NodeConnection> nodes = new HashMap<>();
nodes.put(0, client.createConnection("host0", 123));
- RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes);
+ RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes);
RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true);
Query query = new Query();
@@ -133,7 +132,7 @@ public class FillTestCase {
Map<Integer, Client.NodeConnection> nodes = new HashMap<>();
nodes.put(0, client.createConnection("host0", 123));
- RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes);
+ RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes);
RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true);
Query query = new Query();
@@ -141,7 +140,6 @@ public class FillTestCase {
result.hits().add(createHit(0, 0));
result.hits().add(createHit(1, 1));
-
factory.createFillInvoker(db()).fill(result, "summaryClass1");
assertEquals("Could not fill hits from unknown node 1", result.hits().getError().getDetailedMessage());
@@ -151,8 +149,7 @@ public class FillTestCase {
List<DocsumField> fields = new ArrayList<>();
fields.add(DocsumField.create("field1", "string"));
fields.add(DocsumField.create("field2", "int64"));
- DocsumDefinitionSet docsums = new DocsumDefinitionSet(Collections.singleton(new DocsumDefinition("summaryClass1",
- fields)));
+ DocsumDefinitionSet docsums = new DocsumDefinitionSet(Collections.singleton(new DocsumDefinition("summaryClass1", fields)));
return new DocumentDatabase("default", docsums, Collections.emptySet());
}
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java
index 687d3e728c0..3cc3257194c 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java
@@ -36,62 +36,6 @@ public class MockClient implements Client {
return new MockNodeConnection(hostname, port);
}
- @Override
- public void getDocsums(List<FastHit> hitsContext, NodeConnection node, CompressionType compression,
- int uncompressedSize, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver,
- double timeoutSeconds) {
- if (malfunctioning) {
- responseReceiver.receive(ResponseOrError.fromError("Malfunctioning"));
- return;
- }
-
- Inspector request = BinaryFormat.decode(compressor.decompress(compressedSlime, compression, uncompressedSize)).get();
- String docsumClass = request.field("class").asString();
- List<Map<String, Object>> docsumsToReturn = new ArrayList<>();
- request.field("gids").traverse((ArrayTraverser)(index, gid) -> {
- GlobalId docId = new GlobalId(gid.asData());
- docsumsToReturn.add(docsums.get(new DocsumKey(node.toString(), docId, docsumClass)));
- });
- Slime responseSlime = new Slime();
- Cursor root = responseSlime.setObject();
- Cursor docsums = root.setArray("docsums");
- for (Map<String, Object> docsumFields : docsumsToReturn) {
- Cursor docsumItem = docsums.addObject();
- Cursor docsum = docsumItem.setObject("docsum");
- for (Map.Entry<String, Object> field : docsumFields.entrySet()) {
- if (field.getValue() instanceof Integer)
- docsum.setLong(field.getKey(), (Integer)field.getValue());
- else if (field.getValue() instanceof String)
- docsum.setString(field.getKey(), (String)field.getValue());
- else
- throw new RuntimeException();
- }
- }
- byte[] slimeBytes = BinaryFormat.encode(responseSlime);
- Compressor.Compression compressionResult = compressor.compress(compression, slimeBytes);
- GetDocsumsResponse response = new GetDocsumsResponse(compressionResult.type().getCode(), slimeBytes.length,
- compressionResult.data(), hitsContext);
- responseReceiver.receive(ResponseOrError.fromResponse(response));
- }
-
- @Override
- public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
- ResponseReceiver responseReceiver, double timeoutSeconds) {
- if (malfunctioning) {
- responseReceiver.receive(ResponseOrError.fromError("Malfunctioning"));
- return;
- }
-
- if(searchResult == null) {
- responseReceiver.receive(ResponseOrError.fromError("No result defined"));
- return;
- }
- var payload = ProtobufSerialization.serializeResult(searchResult);
- var compressionResult = compressor.compress(compression, payload);
- var response = new ProtobufResponse(compressionResult.type().getCode(), payload.length, compressionResult.data());
- responseReceiver.receive(ResponseOrError.fromResponse(response));
- }
-
public void setDocsumReponse(String nodeId, int docId, String docsumClass, Map<String, Object> docsumValues) {
docsums.put(new DocsumKey(nodeId, globalIdFrom(docId), docsumClass), docsumValues);
}
@@ -100,7 +44,7 @@ public class MockClient implements Client {
return new GlobalId(new IdIdString("", "test", "", String.valueOf(hitId)));
}
- private static class MockNodeConnection implements Client.NodeConnection {
+ private class MockNodeConnection implements Client.NodeConnection {
private final String hostname;
@@ -109,6 +53,61 @@ public class MockClient implements Client {
}
@Override
+ public void getDocsums(List<FastHit> hitsContext, CompressionType compression, int uncompressedSize, byte[] compressedSlime,
+ RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) {
+ if (malfunctioning) {
+ responseReceiver.receive(ResponseOrError.fromError("Malfunctioning"));
+ return;
+ }
+
+ Inspector request = BinaryFormat.decode(compressor.decompress(compressedSlime, compression, uncompressedSize)).get();
+ String docsumClass = request.field("class").asString();
+ List<Map<String, Object>> docsumsToReturn = new ArrayList<>();
+ request.field("gids").traverse((ArrayTraverser) (index, gid) -> {
+ GlobalId docId = new GlobalId(gid.asData());
+ docsumsToReturn.add(docsums.get(new DocsumKey(toString(), docId, docsumClass)));
+ });
+ Slime responseSlime = new Slime();
+ Cursor root = responseSlime.setObject();
+ Cursor docsums = root.setArray("docsums");
+ for (Map<String, Object> docsumFields : docsumsToReturn) {
+ Cursor docsumItem = docsums.addObject();
+ Cursor docsum = docsumItem.setObject("docsum");
+ for (Map.Entry<String, Object> field : docsumFields.entrySet()) {
+ if (field.getValue() instanceof Integer)
+ docsum.setLong(field.getKey(), (Integer) field.getValue());
+ else if (field.getValue() instanceof String)
+ docsum.setString(field.getKey(), (String) field.getValue());
+ else
+ throw new RuntimeException();
+ }
+ }
+ byte[] slimeBytes = BinaryFormat.encode(responseSlime);
+ Compressor.Compression compressionResult = compressor.compress(compression, slimeBytes);
+ GetDocsumsResponse response = new GetDocsumsResponse(compressionResult.type().getCode(), slimeBytes.length,
+ compressionResult.data(), hitsContext);
+ responseReceiver.receive(ResponseOrError.fromResponse(response));
+ }
+
+ @Override
+ public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
+ ResponseReceiver responseReceiver, double timeoutSeconds) {
+ if (malfunctioning) {
+ responseReceiver.receive(ResponseOrError.fromError("Malfunctioning"));
+ return;
+ }
+
+ if(searchResult == null) {
+ responseReceiver.receive(ResponseOrError.fromError("No result defined"));
+ return;
+ }
+ var payload = ProtobufSerialization.serializeResult(searchResult);
+ var compressionResult = compressor.compress(compression, payload);
+ var response = new ProtobufResponse(compressionResult.type().getCode(), payload.length, compressionResult.data());
+ responseReceiver.receive(ResponseOrError.fromResponse(response));
+ }
+
+ @Override
public void close() { }
@Override
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java
index 64863b9a8a6..d629bd36bb1 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java
@@ -34,7 +34,7 @@ public class RpcSearchInvokerTest {
var payloadHolder = new AtomicReference<byte[]>();
var lengthHolder = new AtomicInteger();
var mockClient = parameterCollectorClient(compressionTypeHolder, payloadHolder, lengthHolder);
- var mockPool = new RpcResourcePool(mockClient, ImmutableMap.of(7, () -> {}));
+ var mockPool = new RpcResourcePool(ImmutableMap.of(7, mockClient.createConnection("foo", 123)));
@SuppressWarnings("resource")
var invoker = new RpcSearchInvoker(mockSearcher(), new Node(7, "seven", 77, 1), mockPool);
@@ -53,23 +53,26 @@ public class RpcSearchInvokerTest {
AtomicInteger lengthHolder) {
return new Client() {
@Override
- public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength,
- byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds) {
- compressionTypeHolder.set(compression);
- payloadHolder.set(compressedPayload);
- lengthHolder.set(uncompressedLength);
- }
+ public NodeConnection createConnection(String hostname, int port) {
+ return new NodeConnection() {
+ @Override
+ public void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength, byte[] compressedSlime,
+ GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) {
+ fail("Unexpected call");
+ }
- @Override
- public void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression, int uncompressedLength,
- byte[] compressedSlime, GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) {
- fail("Unexpected call");
- }
+ @Override
+ public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload,
+ ResponseReceiver responseReceiver, double timeoutSeconds) {
+ compressionTypeHolder.set(compression);
+ payloadHolder.set(compressedPayload);
+ lengthHolder.set(uncompressedLength);
+ }
- @Override
- public NodeConnection createConnection(String hostname, int port) {
- fail("Unexpected call");
- return null;
+ @Override
+ public void close() {
+ }
+ };
}
};
}
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index 8eaf4cc08cb..c05c3589a30 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -77,7 +77,7 @@ public class QueryProfileTypeTestCase {
type.addField(new FieldDescription("myBoolean", FieldType.fromString("boolean", registry)), registry);
type.addField(new FieldDescription("ranking.features.query(myTensor1)", FieldType.fromString("tensor(a{},b{})", registry)), registry);
type.addField(new FieldDescription("ranking.features.query(myTensor2)", FieldType.fromString("tensor(x[2],y[2])", registry)), registry);
- type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor(x{})",registry)), registry);
+ type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor<float>(x{})",registry)), registry);
type.addField(new FieldDescription("myQuery", FieldType.fromString("query", registry)), registry);
type.addField(new FieldDescription("myQueryProfile", FieldType.fromString("query-profile", registry),"qp"), registry);
}
@@ -136,7 +136,7 @@ public class QueryProfileTypeTestCase {
assertEquals(true, properties.get("myBoolean"));
assertEquals(Tensor.from(tensorString1), properties.get("ranking.features.query(myTensor1)"));
assertEquals(Tensor.from("tensor(x[2],y[2])", tensorString2), properties.get("ranking.features.query(myTensor2)"));
- assertEquals(Tensor.from("tensor(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)"));
+ assertEquals(Tensor.from("tensor<float>(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)"));
// TODO: assertEquals(..., cprofile.get("myQuery"));
assertEquals("value1", properties.get("myQueryProfile.anyString"));
assertEquals("value1", properties.get("QP.anyString"));
diff --git a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
index 3fa7f1ee47e..b5c4166e4de 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.search.yql;
import static org.junit.Assert.*;
+import com.yahoo.search.query.QueryTree;
import org.apache.http.client.utils.URIBuilder;
import org.junit.After;
import org.junit.Before;
@@ -29,20 +30,20 @@ public class UserInputTestCase {
@Before
public void setUp() throws Exception {
- searchChain = new Chain<Searcher>(new MinimalQueryInserter());
+ searchChain = new Chain<>(new MinimalQueryInserter());
context = Execution.Context.createContextStub(null);
execution = new Execution(searchChain, context);
}
@After
- public void tearDown() throws Exception {
+ public void tearDown() {
searchChain = null;
context = null;
execution = null;
}
@Test
- public final void testSimpleUserInput() {
+ public void testSimpleUserInput() {
{
URIBuilder builder = searchUri();
builder.setParameter("yql",
@@ -70,7 +71,7 @@ public class UserInputTestCase {
}
@Test
- public final void testRawUserInput() {
+ public void testRawUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"raw\"}]userInput(\"nal le\");");
@@ -79,7 +80,7 @@ public class UserInputTestCase {
}
@Test
- public final void testSegmentedUserInput() {
+ public void testSegmentedUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"nal le\");");
@@ -88,7 +89,7 @@ public class UserInputTestCase {
}
@Test
- public final void testSegmentedNoiseUserInput() {
+ public void testSegmentedNoiseUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"^^^^^^^^\");");
@@ -97,7 +98,7 @@ public class UserInputTestCase {
}
@Test
- public final void testCustomDefaultIndexUserInput() {
+ public void testCustomDefaultIndexUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"defaultIndex\": \"glompf\"}]userInput(\"nalle\");");
@@ -106,7 +107,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputStemming() {
+ public void testAnnotatedUserInputStemming() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"stem\": false}]userInput(\"nalle\");");
@@ -117,7 +118,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputUnrankedTerms() {
+ public void testAnnotatedUserInputUnrankedTerms() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"ranked\": false}]userInput(\"nalle\");");
@@ -128,7 +129,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputFiltersTerms() {
+ public void testAnnotatedUserInputFiltersTerms() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"filter\": true}]userInput(\"nalle\");");
@@ -139,7 +140,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputCaseNormalization() {
+ public void testAnnotatedUserInputCaseNormalization() {
URIBuilder builder = searchUri();
builder.setParameter(
"yql",
@@ -151,7 +152,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputAccentRemoval() {
+ public void testAnnotatedUserInputAccentRemoval() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"accentDrop\": false}]userInput(\"nalle\");");
@@ -162,7 +163,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputPositionData() {
+ public void testAnnotatedUserInputPositionData() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"usePositionData\": false}]userInput(\"nalle\");");
@@ -173,7 +174,7 @@ public class UserInputTestCase {
}
@Test
- public final void testQueryPropertiesAsStringArguments() {
+ public void testQueryPropertiesAsStringArguments() {
URIBuilder builder = searchUri();
builder.setParameter("nalle", "bamse");
builder.setParameter("meta", "syntactic");
@@ -197,7 +198,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyUserInput() {
+ public void testEmptyUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where userInput(\"\");");
@@ -205,7 +206,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyUserInputFromQueryProperty() {
+ public void testEmptyUserInputFromQueryProperty() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql",
@@ -214,7 +215,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyQueryProperty() {
+ public void testEmptyQueryProperty() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql", "select * from sources * where bar contains \"a\" and nonEmpty(foo contains @foo);");
@@ -222,7 +223,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyQueryPropertyInsideExpression() {
+ public void testEmptyQueryPropertyInsideExpression() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql",
@@ -231,7 +232,7 @@ public class UserInputTestCase {
}
@Test
- public final void testCompositeWithoutArguments() {
+ public void testCompositeWithoutArguments() {
URIBuilder builder = searchUri();
builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains phrase();");
searchAndAssertNoErrors(builder);
@@ -241,7 +242,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnoyingPlacementOfNonEmpty() {
+ public void testAnnoyingPlacementOfNonEmpty() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where bar contains \"a\" and foo contains nonEmpty(phrase(\"a\", \"b\"));");
@@ -254,7 +255,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAllowEmptyUserInput() {
+ public void testAllowEmptyUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);");
@@ -262,7 +263,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAllowEmptyNullFromQueryParsing() {
+ public void testAllowEmptyNullFromQueryParsing() {
URIBuilder builder = searchUri();
builder.setParameter("foo", ",,,,,,,,");
builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);");
@@ -270,7 +271,7 @@ public class UserInputTestCase {
}
@Test
- public final void testDisallowEmptyNullFromQueryParsing() {
+ public void testDisallowEmptyNullFromQueryParsing() {
URIBuilder builder = searchUri();
builder.setParameter("foo", ",,,,,,,,");
builder.setParameter("yql", "select * from sources * where userInput(@foo);");
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java
new file mode 100644
index 00000000000..199f233835f
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java
@@ -0,0 +1,116 @@
+package com.yahoo.vespa.hosted.controller.api.integration.user;
+
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.TenantName;
+import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole;
+import com.yahoo.vespa.hosted.controller.api.role.RoleDefinition;
+import com.yahoo.vespa.hosted.controller.api.role.Role;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.api.role.TenantRole;
+
+import java.util.Objects;
+
+/**
+ * An identifier for a role which users identified by {@link UserId}s can be members of, corresponding to a bound {@link Role}.
+ *
+ * @author jonmv
+ */
+public class RoleId {
+
+ private final String value;
+
+ private RoleId(String value) {
+ if (value.isBlank())
+ throw new IllegalArgumentException("Id value must be non-blank.");
+ this.value = value;
+ }
+
+ public static RoleId fromRole(TenantRole role) {
+ return new RoleId(valueOf(role));
+ }
+
+ public static RoleId fromRole(ApplicationRole role) {
+ return new RoleId(valueOf(role));
+ }
+
+ public static RoleId fromValue(String value) {
+ return new RoleId(value);
+ }
+
+ public String value() {
+ return value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RoleId id = (RoleId) o;
+ return Objects.equals(value, id.value);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(value);
+ }
+
+ @Override
+ public String toString() {
+ return "role '" + value + "'";
+ }
+
+ /** Returns the {@link Role} this represent. */
+ public Role toRole(Roles roles) {
+ String[] parts = value.split("\\.");
+ if (parts.length == 2) switch (parts[1]) {
+ case "tenantOwner": return roles.tenantOwner(TenantName.from(parts[0]));
+ case "tenantAdmin": return roles.tenantAdmin(TenantName.from(parts[0]));
+ case "tenantOperator": return roles.tenantOperator(TenantName.from(parts[0]));
+ }
+ if (parts.length == 3) switch (parts[2]) {
+ case "applicationOwner": return roles.applicationOwner(TenantName.from(parts[0]), ApplicationName.from(parts[1]));
+ case "applicationAdmin": return roles.applicationAdmin(TenantName.from(parts[0]), ApplicationName.from(parts[1]));
+ case "applicationOperator": return roles.applicationOperator(TenantName.from(parts[0]), ApplicationName.from(parts[1]));
+ case "applicationDeveloper": return roles.applicationDeveloper(TenantName.from(parts[0]), ApplicationName.from(parts[1]));
+ case "applicationReader": return roles.applicationReader(TenantName.from(parts[0]), ApplicationName.from(parts[1]));
+ }
+ throw new IllegalArgumentException("Malformed or illegal role value '" + value + "'.");
+ }
+
+ private static String valueOf(TenantRole role) {
+ return valueOf(role.tenant()) + "." + valueOf(role.definition());
+ }
+
+ private static String valueOf(ApplicationRole role) {
+ return valueOf(role.tenant()) + "." + valueOf(role.application()) + "." + valueOf(role.definition());
+ }
+
+ private static String valueOf(TenantName tenant) {
+ if (tenant.value().contains("."))
+ throw new IllegalArgumentException("Tenant names may not contain '.'.");
+
+ return tenant.value();
+ }
+
+ private static String valueOf(ApplicationName application) {
+ if (application.value().contains("."))
+ throw new IllegalArgumentException("Application names may not contain '.'.");
+
+ return application.value();
+ }
+
+ private static String valueOf(RoleDefinition role) {
+ switch (role) {
+ case tenantOwner: return "tenantOwner";
+ case tenantAdmin: return "tenantAdmin";
+ case tenantOperator: return "tenantOperator";
+ case applicationOwner: return "applicationOwner";
+ case applicationAdmin: return "applicationAdmin";
+ case applicationOperator: return "applicationOperator";
+ case applicationDeveloper: return "applicationDeveloper";
+ case applicationReader: return "applicationReader";
+ default: throw new IllegalArgumentException("No value defined for role '" + role + "'.");
+ }
+ }
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java
new file mode 100644
index 00000000000..3b138d0ce18
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java
@@ -0,0 +1,40 @@
+package com.yahoo.vespa.hosted.controller.api.integration.user;
+
+import java.util.Objects;
+
+/**
+ * An identifier for a user.
+ *
+ * @author jonmv
+ */
+public class UserId {
+
+ private final String value;
+
+ public UserId(String value) {
+ if (value.isBlank())
+ throw new IllegalArgumentException("Id must be non-blank.");
+ this.value = value;
+ }
+
+ public String value() { return value; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ UserId id = (UserId) o;
+ return Objects.equals(value, id.value);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(value);
+ }
+
+ @Override
+ public String toString() {
+ return "user '" + value + "'";
+ }
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java
new file mode 100644
index 00000000000..c78dcc76854
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java
@@ -0,0 +1,33 @@
+package com.yahoo.vespa.hosted.controller.api.integration.user;
+
+import com.yahoo.vespa.hosted.controller.api.role.Role;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Management of {@link UserId}s and {@link RoleId}s, used for access control with {@link Role}s.
+ *
+ * @author jonmv
+ */
+public interface UserManagement {
+
+ /** Creates the given role, or throws if the role already exists. */
+ void createRole(RoleId role);
+
+ /** Deletes the given role, or throws if it doesn't already exist. */
+ void deleteRole(RoleId role);
+
+ /** Ensures the given users exist, and are part of the given role, or throws if the role does not exist. */
+ void addUsers(RoleId role, Collection<UserId> users);
+
+ /** Ensures none of the given users are part of the given role, or throws if the role does not exist. */
+ void removeUsers(RoleId role, Collection<UserId> users);
+
+ /** Returns all known roles. */
+ List<RoleId> listRoles();
+
+ /** Returns all users in the given role, or throws if the role does not exist. */
+ List<UserId> listUsers(RoleId role);
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java
new file mode 100644
index 00000000000..ca595bab172
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java
@@ -0,0 +1,5 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package com.yahoo.vespa.hosted.controller.api.integration.user;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Action.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Action.java
index 533c28905a9..2d9ef25d1f5 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Action.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Action.java
@@ -1,5 +1,5 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import com.yahoo.jdisc.http.HttpRequest;
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java
new file mode 100644
index 00000000000..cc1e8462580
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java
@@ -0,0 +1,29 @@
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.TenantName;
+
+/**
+ * A {@link Role} with a {@link Context} of a {@link SystemName} a {@link TenantName} and an {@link ApplicationName}.
+ *
+ * @author jonmv
+ */
+public class ApplicationRole extends Role {
+
+ ApplicationRole(RoleDefinition roleDefinition, SystemName system, TenantName tenant, ApplicationName application) {
+ super(roleDefinition, Context.limitedTo(tenant, application, system));
+ }
+
+ /** Returns the {@link TenantName} this is bound to. */
+ public TenantName tenant() { return context.tenant().get(); }
+
+ /** Returns the {@link ApplicationName} this is bound to. */
+ public ApplicationName application() { return context.application().get(); }
+
+ @Override
+ public String toString() {
+ return "role '" + definition() + "' of '" + application() + "' owned by '" + tenant() + "'";
+ }
+
+}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Context.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Context.java
index 71452a3ef20..3ba0367a00c 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Context.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Context.java
@@ -1,5 +1,5 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.SystemName;
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
index ef97421119f..edf3f4e8711 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/PathGroup.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
@@ -1,5 +1,5 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import com.yahoo.restapi.Path;
@@ -51,10 +51,10 @@ public enum PathGroup {
Matcher.application,
"/application/v4/tenant/{tenant}/application/{application}/deploying/{*}",
"/application/v4/tenant/{tenant}/application/{application}/instance/{*}",
- "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/logs",
- "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/suspended",
- "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/service/{*}",
- "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/global-rotation/{*}"),
+ "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/logs",
+ "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/suspended",
+ "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/service/{*}",
+ "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/global-rotation/{*}"),
/** Path used to restart application nodes. */ // TODO move to the above when everyone is on new pipeline.
applicationRestart(Matcher.tenant,
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Policy.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java
index 6ae68f598f0..970717b14a3 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Policy.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java
@@ -1,5 +1,5 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.SystemName;
@@ -39,9 +39,14 @@ public enum Policy {
.in(SystemName.main, SystemName.cd, SystemName.dev)), // TODO SystemName.all()
/** Full access to tenant information and settings. */
- tenantWrite(Privilege.grant(Action.write())
- .on(PathGroup.tenant)
- .in(SystemName.all())),
+ tenantDelete(Privilege.grant(Action.delete)
+ .on(PathGroup.tenant)
+ .in(SystemName.all())),
+
+ /** Full access to tenant information and settings. */
+ tenantUpdate(Privilege.grant(Action.update)
+ .on(PathGroup.tenant)
+ .in(SystemName.all())),
/** Read access to tenant information and settings. */
tenantRead(Privilege.grant(Action.read)
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Privilege.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Privilege.java
index 4c5ad136f56..a53717b25d6 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Privilege.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Privilege.java
@@ -1,5 +1,5 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import com.yahoo.config.provision.SystemName;
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java
new file mode 100644
index 00000000000..86d59b4bbb6
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java
@@ -0,0 +1,48 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import java.net.URI;
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * A role is a combination of a {@link RoleDefinition} and a {@link Context}, which allows evaluation
+ * of access control for a given action on a resource. Create using {@link Roles}.
+ *
+ * @author jonmv
+ */
+public abstract class Role {
+
+ private final RoleDefinition roleDefinition;
+ final Context context;
+
+ Role(RoleDefinition roleDefinition, Context context) {
+ this.roleDefinition = requireNonNull(roleDefinition);
+ this.context = requireNonNull(context);
+ }
+
+ /** Returns the role definition of this bound role. */
+ public RoleDefinition definition() { return roleDefinition; }
+
+ /** Returns whether this role is allowed to perform the given action on the given resource. */
+ public boolean allows(Action action, URI uri) {
+ return roleDefinition.policies().stream().anyMatch(policy -> policy.evaluate(action, uri, context));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Role role = (Role) o;
+ return roleDefinition == role.roleDefinition &&
+ Objects.equals(context, role.context);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(roleDefinition, context);
+ }
+
+}
+
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Role.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java
index d82e4063391..e9c2f7bc643 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Role.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java
@@ -1,21 +1,20 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
-import java.net.URI;
import java.util.EnumSet;
import java.util.Set;
/**
* This declares all tenant roles known to the controller. A role contains one or more {@link Policy}s which decide
- * what actions a member of a role can perform.
+ * what actions a member of a role can perform, given a {@link Context} for the action.
*
- * Optionally, some role definition also inherit all policies from a "lower ranking" role. Read the list of roles
- * from {@code everyone} to {@code tenantAdmin}, in order, to see what policies these roles.
+ * Optionally, some role definitions also inherit all policies from a "lower ranking" role.
+ *
+ * See {@link Role} for roles bound to a context, where policies can be evaluated.
*
* @author mpolden
* @author jonmv
*/
-public enum Role {
+public enum RoleDefinition {
/** Deus ex machina. */
hostedOperator(Policy.operator),
@@ -50,45 +49,52 @@ public enum Role {
Policy.productionDeployment,
Policy.submission),
- /** Tenant admin with full access to all tenant resources, including the ability to create new applications. */
- tenantAdmin(applicationAdmin,
- Policy.applicationCreate,
+ /** Application administrator with the additional ability to delete an application. */
+ applicationOwner(applicationOperator,
+ Policy.applicationDelete),
+
+ /** Tenant operator with admin access to all applications under the tenant, as well as the ability to create applications. */
+ tenantOperator(applicationAdmin,
+ Policy.applicationCreate),
+
+ /** Tenant admin with full access to all tenant resources, except deleting the tenant. */
+ tenantAdmin(tenantOperator,
Policy.applicationDelete,
Policy.manager,
- Policy.tenantWrite),
+ Policy.tenantUpdate),
+
+ /** Tenant admin with full access to all tenant resources. */
+ tenantOwner(tenantAdmin,
+ Policy.tenantDelete),
/** Build and continuous delivery service. */ // TODO replace with buildService, when everyone is on new pipeline.
- tenantPipeline(Policy.submission,
+ tenantPipeline(everyone,
+ Policy.submission,
Policy.deploymentPipeline,
Policy.productionDeployment),
/** Tenant administrator with full access to all child resources. */
- athenzTenantAdmin(Policy.tenantWrite,
+ athenzTenantAdmin(everyone,
Policy.tenantRead,
+ Policy.tenantUpdate,
+ Policy.tenantDelete,
Policy.applicationCreate,
Policy.applicationUpdate,
Policy.applicationDelete,
Policy.applicationOperations,
- Policy.developmentDeployment); // TODO remove, as it is covered by applicationAdmin.
+ Policy.developmentDeployment);
private final Set<Policy> policies;
- Role(Policy... policies) {
+ RoleDefinition(Policy... policies) {
this.policies = EnumSet.copyOf(Set.of(policies));
}
- Role(Role inherited, Policy... policies) {
+ RoleDefinition(RoleDefinition inherited, Policy... policies) {
this.policies = EnumSet.copyOf(Set.of(policies));
this.policies.addAll(inherited.policies);
}
- /**
- * Returns whether this role is allowed to perform action in given role context. Action is allowed if at least one
- * policy evaluates to true.
- */
- public boolean allows(Action action, URI uri, Context context) {
- return policies.stream().anyMatch(policy -> policy.evaluate(action, uri, context));
- }
+ Set<Policy> policies() { return policies; }
}
-
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java
new file mode 100644
index 00000000000..f6149bf6e88
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java
@@ -0,0 +1,104 @@
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.google.inject.Inject;
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.TenantName;
+import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneRegistry;
+
+import java.util.Objects;
+
+/**
+ * Use if you need to create {@link Role}s for its system.
+ *
+ * This also defines the relationship between {@link RoleDefinition}s and their required {@link Context}s.
+ *
+ * @author jonmv
+ */
+public class Roles {
+
+ private final SystemName system;
+
+
+ @Inject
+ public Roles(ZoneRegistry zones) {
+ this(zones.system());
+ }
+
+ /** Creates a Roles which can be used to create bound roles for the given system. */
+ public Roles(SystemName system) {
+ this.system = Objects.requireNonNull(system);
+ }
+
+
+ // General roles.
+ /** Returns a {@link RoleDefinition#hostedOperator} for the current system. */
+ public UnboundRole hostedOperator() {
+ return new UnboundRole(RoleDefinition.hostedOperator, system);
+ }
+
+ /** Returns a {@link RoleDefinition#everyone} for the current system. */
+ public UnboundRole everyone() {
+ return new UnboundRole(RoleDefinition.everyone, system);
+ }
+
+
+ // Athenz based roles.
+ /** Returns a {@link RoleDefinition#athenzTenantAdmin} for the current system and given tenant. */
+ public TenantRole athenzTenantAdmin(TenantName tenant) {
+ return new TenantRole(RoleDefinition.athenzTenantAdmin, system, tenant);
+ }
+
+ /** Returns a {@link RoleDefinition#tenantPipeline} for the current system and given tenant and application. */
+ public ApplicationRole tenantPipeline(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.tenantPipeline, system, tenant, application);
+ }
+
+
+ // Other identity provider based roles.
+ /** Returns a {@link RoleDefinition#tenantOwner} for the current system and given tenant. */
+ public TenantRole tenantOwner(TenantName tenant) {
+ return new TenantRole(RoleDefinition.tenantOwner, system, tenant);
+ }
+
+ /** Returns a {@link RoleDefinition#tenantAdmin} for the current system and given tenant. */
+ public TenantRole tenantAdmin(TenantName tenant) {
+ return new TenantRole(RoleDefinition.tenantAdmin, system, tenant);
+ }
+
+ /** Returns a {@link RoleDefinition#tenantOperator} for the current system and given tenant. */
+ public TenantRole tenantOperator(TenantName tenant) {
+ return new TenantRole(RoleDefinition.tenantOperator, system, tenant);
+ }
+
+ /** Returns a {@link RoleDefinition#applicationOwner} for the current system and given tenant and application. */
+ public ApplicationRole applicationOwner(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.applicationOwner, system, tenant, application);
+ }
+
+ /** Returns a {@link RoleDefinition#applicationAdmin} for the current system and given tenant and application. */
+ public ApplicationRole applicationAdmin(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.applicationAdmin, system, tenant, application);
+ }
+
+ /** Returns a {@link RoleDefinition#applicationOperator} for the current system and given tenant and application. */
+ public ApplicationRole applicationOperator(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.applicationOperator, system, tenant, application);
+ }
+
+ /** Returns a {@link RoleDefinition#applicationDeveloper} for the current system and given tenant and application. */
+ public ApplicationRole applicationDeveloper(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.applicationDeveloper, system, tenant, application);
+ }
+
+ /** Returns a {@link RoleDefinition#applicationReader} for the current system and given tenant and application. */
+ public ApplicationRole applicationReader(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.applicationReader, system, tenant, application);
+ }
+
+ /** Returns a {@link RoleDefinition#buildService} for the current system and given tenant and application. */
+ public ApplicationRole buildService(TenantName tenant, ApplicationName application) {
+ return new ApplicationRole(RoleDefinition.buildService, system, tenant, application);
+ }
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java
new file mode 100644
index 00000000000..41444258a68
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java
@@ -0,0 +1,51 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import java.security.Principal;
+import java.util.Objects;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+public class SecurityContext {
+
+ public static final String ATTRIBUTE_NAME = SecurityContext.class.getName();
+
+ private final Principal principal;
+ private final Set<Role> roles;
+
+ public SecurityContext(Principal principal, Set<Role> roles) {
+ this.principal = requireNonNull(principal);
+ this.roles = Set.copyOf(roles);
+ }
+
+ public Principal principal() {
+ return principal;
+ }
+
+ public Set<Role> roles() {
+ return roles;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ SecurityContext that = (SecurityContext) o;
+ return Objects.equals(principal, that.principal) &&
+ Objects.equals(roles, that.roles);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(principal, roles);
+ }
+
+ @Override
+ public String toString() {
+ return "SecurityContext{" +
+ "principal=" + principal +
+ ", roles=" + roles +
+ '}';
+ }
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java
new file mode 100644
index 00000000000..134628ec3a3
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java
@@ -0,0 +1,25 @@
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.TenantName;
+
+/**
+ * A {@link Role} with a {@link Context} of a {@link SystemName} and a {@link TenantName}.
+ *
+ * @author jonmv
+ */
+public class TenantRole extends Role {
+
+ TenantRole(RoleDefinition roleDefinition, SystemName system, TenantName tenant) {
+ super(roleDefinition, Context.limitedTo(tenant, system));
+ }
+
+ /** Returns the {@link TenantName} this is bound to. */
+ public TenantName tenant() { return context.tenant().get(); }
+
+ @Override
+ public String toString() {
+ return "role '" + definition() + "' of '" + tenant() + "'";
+ }
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java
new file mode 100644
index 00000000000..eb8319b2012
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java
@@ -0,0 +1,21 @@
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.yahoo.config.provision.SystemName;
+
+/**
+ * A {@link Role} with a {@link Context} of only a {@link SystemName}.
+ *
+ * @author jonmv
+ */
+public class UnboundRole extends Role {
+
+ UnboundRole(RoleDefinition roleDefinition, SystemName system) {
+ super(roleDefinition, Context.unlimitedIn(system));
+ }
+
+ @Override
+ public String toString() {
+ return "role '" + definition() + "'";
+ }
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java
new file mode 100644
index 00000000000..a7f70d6fe3c
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java
@@ -0,0 +1,5 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java
new file mode 100644
index 00000000000..609646eb672
--- /dev/null
+++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java
@@ -0,0 +1,74 @@
+package com.yahoo.vespa.hosted.controller.api.integration.user;
+
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.TenantName;
+import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.api.role.TenantRole;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author jonmv
+ */
+public class RoleIdTest {
+
+ @Test
+ public void testSerialization() {
+ Roles roles = new Roles(SystemName.main);
+
+ TenantName tenant = TenantName.from("my-tenant");
+ for (TenantRole role : List.of(roles.tenantOwner(tenant),
+ roles.tenantAdmin(tenant),
+ roles.tenantOperator(tenant)))
+ assertEquals(role, RoleId.fromRole(role).toRole(roles));
+
+ ApplicationName application = ApplicationName.from("my-application");
+ for (ApplicationRole role : List.of(roles.applicationOwner(tenant, application),
+ roles.applicationAdmin(tenant, application),
+ roles.applicationOperator(tenant, application),
+ roles.applicationDeveloper(tenant, application),
+ roles.applicationReader(tenant, application)))
+ assertEquals(role, RoleId.fromRole(role).toRole(roles));
+
+ assertEquals(roles.tenantOperator(tenant),
+ RoleId.fromValue("my-tenant.tenantOperator").toRole(roles));
+ assertEquals(roles.applicationReader(tenant, application),
+ RoleId.fromValue("my-tenant.my-application.applicationReader").toRole(roles));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalTenantName() {
+ RoleId.fromRole(new Roles(SystemName.main).tenantAdmin(TenantName.from("my.tenant")));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalApplicationName() {
+ RoleId.fromRole(new Roles(SystemName.main).applicationOperator(TenantName.from("my-tenant"), ApplicationName.from("my.app")));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalRole() {
+ RoleId.fromRole(new Roles(SystemName.main).tenantPipeline(TenantName.from("my-tenant"), ApplicationName.from("my-app")));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalRoleValue() {
+ RoleId.fromValue("my-tenant.awesomePerson").toRole(new Roles(SystemName.cd));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalCombination() {
+ RoleId.fromValue("my-tenant.my-application.tenantOwner").toRole(new Roles(SystemName.cd));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void illegalValue() {
+ RoleId.fromValue("hostedOperator").toRole(new Roles(SystemName.Public));
+ }
+
+}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/PathGroupTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/PathGroupTest.java
index b4a3e674594..9d76d055877 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/PathGroupTest.java
+++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/PathGroupTest.java
@@ -1,9 +1,8 @@
-package com.yahoo.vespa.hosted.controller.role;
+package com.yahoo.vespa.hosted.controller.api.role;
import org.junit.Test;
import java.util.HashSet;
-import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Pattern;
diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java
new file mode 100644
index 00000000000..1badd157b1b
--- /dev/null
+++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java
@@ -0,0 +1,54 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.role;
+
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.TenantName;
+import org.junit.Test;
+
+import java.net.URI;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author mpolden
+ */
+public class RoleTest {
+
+ @Test
+ public void operator_membership() {
+ Role role = new Roles(SystemName.main).hostedOperator();
+
+ // Operator actions
+ assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined")));
+ assertTrue(role.allows(Action.create, URI.create("/controller/v1/foo")));
+ assertTrue(role.allows(Action.update, URI.create("/os/v1/bar")));
+ assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
+ assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2")));
+ }
+
+ @Test
+ public void tenant_membership() {
+ Role role = new Roles(SystemName.main).athenzTenantAdmin(TenantName.from("t1"));
+ assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined")));
+ assertFalse("Deny access to operator API", role.allows(Action.create, URI.create("/controller/v1/foo")));
+ assertFalse("Deny access to other tenant and app", role.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2")));
+ assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
+
+ Role publicSystem = new Roles(SystemName.vaas).athenzTenantAdmin(TenantName.from("t1"));
+ assertFalse(publicSystem.allows(Action.read, URI.create("/controller/v1/foo")));
+ assertTrue(publicSystem.allows(Action.read, URI.create("/badge/v1/badge")));
+ assertTrue(publicSystem.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
+ }
+
+ @Test
+ public void build_service_membership() {
+ Role role = new Roles(SystemName.vaas).tenantPipeline(TenantName.from("t1"), ApplicationName.from("a1"));
+ assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined")));
+ assertFalse(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
+ assertTrue(role.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport")));
+ assertFalse("No global read access", role.allows(Action.read, URI.create("/controller/v1/foo")));
+ }
+
+}
diff --git a/controller-server/pom.xml b/controller-server/pom.xml
index c4cb66de3ec..f22142db727 100644
--- a/controller-server/pom.xml
+++ b/controller-server/pom.xml
@@ -100,6 +100,13 @@
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>flags</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
<!-- compile -->
<dependency>
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
index 1d685895914..b6993fbc421 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
@@ -13,6 +13,9 @@ import com.yahoo.vespa.athenz.api.AthenzDomain;
import com.yahoo.vespa.athenz.api.AthenzPrincipal;
import com.yahoo.vespa.athenz.api.AthenzUser;
import com.yahoo.vespa.curator.Lock;
+import com.yahoo.vespa.flags.BooleanFlag;
+import com.yahoo.vespa.flags.FetchVector;
+import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.hosted.controller.api.ActivateResult;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus;
@@ -42,6 +45,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import com.yahoo.vespa.hosted.controller.application.ApplicationPackage;
import com.yahoo.vespa.hosted.controller.application.Deployment;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics;
+import com.yahoo.vespa.hosted.controller.application.GlobalDnsName;
import com.yahoo.vespa.hosted.controller.application.JobList;
import com.yahoo.vespa.hosted.controller.application.JobStatus;
import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun;
@@ -112,6 +116,7 @@ public class ApplicationController {
private final ConfigServer configServer;
private final RoutingGenerator routingGenerator;
private final Clock clock;
+ private final BooleanFlag redirectLegacyDnsFlag;
private final DeploymentTrigger deploymentTrigger;
@@ -127,6 +132,7 @@ public class ApplicationController {
this.configServer = configServer;
this.routingGenerator = routingGenerator;
this.clock = clock;
+ this.redirectLegacyDnsFlag = Flags.REDIRECT_LEGACY_DNS_NAMES.bindTo(controller.flagSource());
this.artifactRepository = artifactRepository;
this.applicationStore = applicationStore;
@@ -231,14 +237,14 @@ public class ApplicationController {
com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId.validate(id.application().value());
Optional<Tenant> tenant = controller.tenants().get(id.tenant());
- if ( ! tenant.isPresent())
+ if (tenant.isEmpty())
throw new IllegalArgumentException("Could not create '" + id + "': This tenant does not exist");
if (get(id).isPresent())
throw new IllegalArgumentException("Could not create '" + id + "': Application already exists");
if (get(dashToUnderscore(id)).isPresent()) // VESPA-1945
throw new IllegalArgumentException("Could not create '" + id + "': Application " + dashToUnderscore(id) + " already exists");
if (tenant.get().type() != Tenant.Type.user) {
- if ( ! credentials.isPresent())
+ if (credentials.isEmpty())
throw new IllegalArgumentException("Could not create '" + id + "': No credentials provided");
if (id.instance().isDefault()) // Only store the application permits for non-user applications.
@@ -269,7 +275,7 @@ public class ApplicationController {
throw new IllegalArgumentException("'" + applicationId + "' is a tester application!");
Tenant tenant = controller.tenants().require(applicationId.tenant());
- if (tenant.type() == Tenant.Type.user && ! get(applicationId).isPresent())
+ if (tenant.type() == Tenant.Type.user && get(applicationId).isEmpty())
createApplication(applicationId, Optional.empty());
try (Lock deploymentLock = lockForDeployment(applicationId, zone)) {
@@ -292,15 +298,15 @@ public class ApplicationController {
() -> new IllegalArgumentException("Application package must be given when deploying to " + zone));
platformVersion = options.vespaVersion.map(Version::new).orElse(applicationPackage.deploymentSpec().majorVersion()
.flatMap(this::lastCompatibleVersion)
- .orElse(controller.systemVersion()));
+ .orElseGet(controller::systemVersion));
}
else {
JobType jobType = JobType.from(controller.system(), zone)
.orElseThrow(() -> new IllegalArgumentException("No job is known for " + zone + "."));
Optional<JobStatus> job = Optional.ofNullable(application.get().deploymentJobs().jobStatus().get(jobType));
- if ( ! job.isPresent()
- || ! job.get().lastTriggered().isPresent()
- || job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at()))
+ if ( job.isEmpty()
+ || job.get().lastTriggered().isEmpty()
+ || job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at()))
return unexpectedDeployment(applicationId, zone);
JobRun triggered = job.get().lastTriggered().get();
platformVersion = preferOldestVersion ? triggered.sourcePlatform().orElse(triggered.platform())
@@ -382,7 +388,7 @@ public class ApplicationController {
application = withoutUnreferencedDeploymentJobs(application);
store(application);
- return(application);
+ return application;
}
/** Deploy a system application to given zone */
@@ -432,20 +438,28 @@ public class ApplicationController {
application = application.with(rotation.id());
store(application); // store assigned rotation even if deployment fails
- registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().dnsName());
- registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().secureDnsName());
- registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().oathDnsName());
+ GlobalDnsName dnsName = application.get().globalDnsName(controller.system())
+ .orElseThrow(() -> new IllegalStateException("Expected rotation to be assigned"));
+ boolean redirectLegacyDns = redirectLegacyDnsFlag.with(FetchVector.Dimension.APPLICATION_ID, application.get().id().serializedForm())
+ .value();
+ registerCname(dnsName.oathDnsName(), rotation.name());
+ if (redirectLegacyDns) {
+ registerCname(dnsName.dnsName(), dnsName.oathDnsName());
+ registerCname(dnsName.secureDnsName(), dnsName.oathDnsName());
+ } else {
+ registerCname(dnsName.dnsName(), rotation.name());
+ registerCname(dnsName.secureDnsName(), rotation.name());
+ }
}
}
return application;
}
- private ActivateResult unexpectedDeployment(ApplicationId applicationId, ZoneId zone) {
-
+ private ActivateResult unexpectedDeployment(ApplicationId application, ZoneId zone) {
Log logEntry = new Log();
logEntry.level = "WARNING";
logEntry.time = clock.instant().toEpochMilli();
- logEntry.message = "Ignoring deployment of " + require(applicationId) + " to " + zone +
+ logEntry.message = "Ignoring deployment of application '" + application + "' to " + zone +
" as a deployment is not currently expected";
PrepareResponse prepareResponse = new PrepareResponse();
prepareResponse.log = Collections.singletonList(logEntry);
@@ -495,24 +509,22 @@ public class ApplicationController {
options.deployCurrentVersion);
}
- /** Register a DNS name for rotation */
- private void registerRotationInDns(Rotation rotation, String dnsName) {
+ /** Register a CNAME record in DNS */
+ private void registerCname(String name, String targetName) {
try {
-
- RecordData rotationName = RecordData.fqdn(rotation.name());
- List<Record> records = nameService.findRecords(Record.Type.CNAME, RecordName.from(dnsName));
+ RecordData data = RecordData.fqdn(targetName);
+ List<Record> records = nameService.findRecords(Record.Type.CNAME, RecordName.from(name));
records.forEach(record -> {
- // Ensure that the existing record points to the correct rotation
- if ( ! record.data().equals(rotationName)) {
- nameService.updateRecord(record, rotationName);
- log.info("Updated mapping for record '" + record + "': '" + dnsName
- + "' -> '" + rotation.name() + "'");
+ // Ensure that the existing record points to the correct target
+ if ( ! record.data().equals(data)) {
+ log.info("Updating mapping for record '" + record + "': '" + name
+ + "' -> '" + data.asString() + "'");
+ nameService.updateRecord(record, data);
}
});
-
if (records.isEmpty()) {
- Record record = nameService.createCname(RecordName.from(dnsName), rotationName);
- log.info("Registered mapping as record '" + record + "'");
+ Record record = nameService.createCname(RecordName.from(name), data);
+ log.info("Registered mapping as record '" + record + "'");
}
} catch (RuntimeException e) {
log.log(Level.WARNING, "Failed to register CNAME", e);
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java
index 6e59c384485..7754286ba9e 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java
@@ -9,8 +9,7 @@ import com.yahoo.config.provision.CloudName;
import com.yahoo.config.provision.HostName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.vespa.curator.Lock;
-import com.yahoo.vespa.hosted.controller.api.identifiers.Property;
-import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId;
+import com.yahoo.vespa.flags.FlagSource;
import com.yahoo.vespa.hosted.controller.api.integration.BuildService;
import com.yahoo.vespa.hosted.controller.api.integration.MetricsService;
import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore;
@@ -76,6 +75,7 @@ public class Controller extends AbstractComponent {
private final Chef chef;
private final Mailer mailer;
private final AuditLogger auditLogger;
+ private final FlagSource flagSource;
/**
* Creates a controller
@@ -88,11 +88,11 @@ public class Controller extends AbstractComponent {
NameService nameService, RoutingGenerator routingGenerator, Chef chef,
AccessControl accessControl,
ArtifactRepository artifactRepository, ApplicationStore applicationStore, TesterCloud testerCloud,
- BuildService buildService, RunDataStore runDataStore, Mailer mailer) {
+ BuildService buildService, RunDataStore runDataStore, Mailer mailer, FlagSource flagSource) {
this(curator, rotationsConfig, gitHub, zoneRegistry,
configServer, metricsService, nameService, routingGenerator, chef,
Clock.systemUTC(), accessControl, artifactRepository, applicationStore, testerCloud,
- buildService, runDataStore, com.yahoo.net.HostName::getLocalhost, mailer);
+ buildService, runDataStore, com.yahoo.net.HostName::getLocalhost, mailer, flagSource);
}
public Controller(CuratorDb curator, RotationsConfig rotationsConfig, GitHub gitHub,
@@ -102,7 +102,7 @@ public class Controller extends AbstractComponent {
AccessControl accessControl,
ArtifactRepository artifactRepository, ApplicationStore applicationStore, TesterCloud testerCloud,
BuildService buildService, RunDataStore runDataStore, Supplier<String> hostnameSupplier,
- Mailer mailer) {
+ Mailer mailer, FlagSource flagSource) {
this.hostnameSupplier = Objects.requireNonNull(hostnameSupplier, "HostnameSupplier cannot be null");
this.curator = Objects.requireNonNull(curator, "Curator cannot be null");
@@ -113,6 +113,7 @@ public class Controller extends AbstractComponent {
this.chef = Objects.requireNonNull(chef, "Chef cannot be null");
this.clock = Objects.requireNonNull(clock, "Clock cannot be null");
this.mailer = Objects.requireNonNull(mailer, "Mailer cannot be null");
+ this.flagSource = Objects.requireNonNull(flagSource, "FlagSource cannot be null");
jobController = new JobController(this, runDataStore, Objects.requireNonNull(testerCloud));
applicationController = new ApplicationController(this, curator, accessControl,
@@ -123,7 +124,8 @@ public class Controller extends AbstractComponent {
Objects.requireNonNull(applicationStore, "ApplicationStore cannot be null"),
Objects.requireNonNull(routingGenerator, "RoutingGenerator cannot be null"),
Objects.requireNonNull(buildService, "BuildService cannot be null"),
- clock);
+ clock
+ );
tenantController = new TenantController(this, curator, accessControl);
auditLogger = new AuditLogger(curator, clock);
@@ -146,6 +148,11 @@ public class Controller extends AbstractComponent {
return mailer;
}
+ /** Provides access to the feature flags of this */
+ public FlagSource flagSource() {
+ return flagSource;
+ }
+
public Clock clock() { return clock; }
public ZoneRegistry zoneRegistry() { return zoneRegistry; }
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java
index e8b3e334631..d1a6e39a1dd 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java
@@ -61,7 +61,7 @@ public class TenantController {
.collect(Collectors.toList());
}
- /** Returns the lsit of tenants accessible to the given user. */
+ /** Returns the list of tenants accessible to the given user. */
public List<Tenant> asList(Credentials credentials) {
return accessControl.accessibleTenants(asList(), credentials);
}
@@ -147,10 +147,11 @@ public class TenantController {
}
private void requireNonExistent(TenantName name) {
- if (get(name).isPresent() ||
+ if ( "hosted-vespa".equals(name.value())
+ || get(name).isPresent()
// Underscores are allowed in existing tenant names, but tenants with - and _ cannot co-exist. E.g.
// my-tenant cannot be created if my_tenant exists.
- get(name.value().replace('-', '_')).isPresent()) {
+ || get(name.value().replace('-', '_')).isPresent()) {
throw new IllegalArgumentException("Tenant '" + name + "' already exists");
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java
index 0254bf2fd38..ae638beed5c 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java
@@ -9,7 +9,7 @@ import java.net.URI;
import java.util.Objects;
/**
- * Represents an application's global rotation.
+ * Represents names for an application's global rotation.
*
* @author mpolden
*/
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java
index 2fe6af02480..7693f224b56 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java
@@ -72,7 +72,7 @@ public class DnsMaintainer extends Maintainer {
private Optional<Rotation> rotationToCheckOf(Collection<Rotation> rotations) {
if (rotations.isEmpty()) return Optional.empty();
List<Rotation> rotationList = new ArrayList<>(rotations);
- int index = rotationIndex.getAndUpdate((i)-> {
+ int index = rotationIndex.getAndUpdate((i) -> {
if (i < rotationList.size() - 1) {
return ++i;
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index fb27247c48a..8f58827d33a 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -367,6 +367,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
private void toSlime(Cursor object, Application application, HttpRequest request) {
+ object.setString("tenant", application.id().tenant().value());
object.setString("application", application.id().application().value());
object.setString("instance", application.id().instance().value());
object.setString("deployments", withPath("/application/v4" +
@@ -456,21 +457,22 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
for (Deployment deployment : deployments) {
Cursor deploymentObject = instancesArray.addObject();
- deploymentObject.setString("environment", deployment.zone().environment().value());
- deploymentObject.setString("region", deployment.zone().region().value());
- deploymentObject.setString("instance", application.id().instance().value()); // pointless
if (application.rotation().isPresent() && deployment.zone().environment() == Environment.prod) {
toSlime(application.rotationStatus(deployment), deploymentObject);
}
if (recurseOverDeployments(request)) // List full deployment information when recursive.
toSlime(deploymentObject, new DeploymentId(application.id(), deployment.zone()), deployment, request);
- else
+ else {
+ deploymentObject.setString("environment", deployment.zone().environment().value());
+ deploymentObject.setString("region", deployment.zone().region().value());
+ deploymentObject.setString("instance", application.id().instance().value()); // pointless
deploymentObject.setString("url", withPath(request.getUri().getPath() +
"/environment/" + deployment.zone().environment().value() +
"/region/" + deployment.zone().region().value() +
"/instance/" + application.id().instance().value(),
request.getUri()).toString());
+ }
}
// Metrics
@@ -516,6 +518,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
private void toSlime(Cursor response, DeploymentId deploymentId, Deployment deployment, HttpRequest request) {
+ response.setString("tenant", deploymentId.applicationId().tenant().value());
+ response.setString("application", deploymentId.applicationId().application().value());
+ response.setString("instance", deploymentId.applicationId().instance().value()); // pointless
+ response.setString("environment", deploymentId.zoneId().environment().value());
+ response.setString("region", deploymentId.zoneId().region().value());
+
Cursor serviceUrlArray = response.setArray("serviceUrls");
controller.applications().getDeploymentEndpoints(deploymentId)
.ifPresent(endpoints -> endpoints.forEach(endpoint -> serviceUrlArray.addString(endpoint.toString())));
@@ -1154,6 +1162,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
private void toSlime(Application application, Cursor object, HttpRequest request) {
+ object.setString("tenant", application.id().tenant().value());
object.setString("application", application.id().application().value());
object.setString("instance", application.id().instance().value());
object.setString("url", withPath("/application/v4/tenant/" + application.id().tenant().value() +
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java
new file mode 100644
index 00000000000..f25deb11a52
--- /dev/null
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java
@@ -0,0 +1,115 @@
+package com.yahoo.vespa.hosted.controller.restapi.filter;
+
+import com.google.inject.Inject;
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.TenantName;
+import com.yahoo.jdisc.Response;
+import com.yahoo.jdisc.http.filter.DiscFilterRequest;
+import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig;
+import com.yahoo.jdisc.http.filter.security.cors.CorsRequestFilterBase;
+import com.yahoo.log.LogLevel;
+import com.yahoo.restapi.Path;
+import com.yahoo.vespa.athenz.api.AthenzDomain;
+import com.yahoo.vespa.athenz.api.AthenzIdentity;
+import com.yahoo.vespa.athenz.api.AthenzPrincipal;
+import com.yahoo.vespa.athenz.client.zms.ZmsClientException;
+import com.yahoo.vespa.hosted.controller.Controller;
+import com.yahoo.vespa.hosted.controller.TenantController;
+import com.yahoo.vespa.hosted.controller.api.role.Role;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction;
+import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade;
+import com.yahoo.vespa.hosted.controller.api.role.SecurityContext;
+import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant;
+import com.yahoo.vespa.hosted.controller.tenant.Tenant;
+import com.yahoo.vespa.hosted.controller.tenant.UserTenant;
+import com.yahoo.yolean.Exceptions;
+
+import java.net.URI;
+import java.util.Optional;
+import java.util.Set;
+import java.util.logging.Logger;
+
+import static com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities.SCREWDRIVER_DOMAIN;
+
+/**
+ * Enriches the request principal with roles from Athenz.
+ *
+ * @author jonmv
+ */
+public class AthenzRoleFilter extends CorsRequestFilterBase { // TODO: No need for this super anyway.
+
+ private static final Logger logger = Logger.getLogger(AthenzRoleFilter.class.getName());
+
+ private final AthenzFacade athenz;
+ private final TenantController tenants;
+ private final Roles roles;
+
+ @Inject
+ public AthenzRoleFilter(CorsFilterConfig config, AthenzFacade athenz, Controller controller) {
+ super(Set.copyOf(config.allowedUrls()));
+ this.athenz = athenz;
+ this.tenants = controller.tenants();
+ this.roles = new Roles(controller.system());
+ }
+
+ @Override
+ protected Optional<ErrorResponse> filterRequest(DiscFilterRequest request) {
+ try {
+ AthenzPrincipal athenzPrincipal = (AthenzPrincipal) request.getUserPrincipal();
+ request.setAttribute(SecurityContext.ATTRIBUTE_NAME, new SecurityContext(athenzPrincipal,
+ roles(athenzPrincipal, request.getUri())));
+ return Optional.empty();
+ }
+ catch (Exception e) {
+ logger.log(LogLevel.DEBUG, () -> "Exception mapping Athenz principal to roles: " + Exceptions.toMessageString(e));
+ return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, "Access denied"));
+ }
+ }
+
+ Set<Role> roles(AthenzPrincipal principal, URI uri) {
+ Path path = new Path(uri);
+
+ path.matches("/application/v4/tenant/{tenant}/{*}");
+ Optional<Tenant> tenant = Optional.ofNullable(path.get("tenant")).map(TenantName::from).flatMap(tenants::get);
+
+ path.matches("/application/v4/tenant/{tenant}/application/{application}/{*}");
+ Optional<ApplicationName> application = Optional.ofNullable(path.get("application")).map(ApplicationName::from);
+
+ AthenzIdentity identity = principal.getIdentity();
+
+ if (athenz.hasHostedOperatorAccess(identity))
+ return Set.of(roles.hostedOperator());
+
+ if (tenant.isPresent() && isTenantAdmin(identity, tenant.get()))
+ return Set.of(roles.athenzTenantAdmin(tenant.get().name()));
+
+ if (identity.getDomain().equals(SCREWDRIVER_DOMAIN) && application.isPresent() && tenant.isPresent())
+ // NOTE: Only fine-grained deploy authorization for Athenz tenants
+ if ( tenant.get().type() != Tenant.Type.athenz
+ || hasDeployerAccess(identity, ((AthenzTenant) tenant.get()).domain(), application.get()))
+ return Set.of(roles.tenantPipeline(tenant.get().name(), application.get()));
+
+ return Set.of(roles.everyone());
+ }
+
+ private boolean isTenantAdmin(AthenzIdentity identity, Tenant tenant) {
+ switch (tenant.type()) {
+ case athenz: return athenz.hasTenantAdminAccess(identity, ((AthenzTenant) tenant).domain());
+ case user: return ((UserTenant) tenant).is(identity.getName()) || athenz.hasHostedOperatorAccess(identity);
+ default: throw new IllegalArgumentException("Unexpected tenant type '" + tenant.type() + "'.");
+ }
+ }
+
+ private boolean hasDeployerAccess(AthenzIdentity identity, AthenzDomain tenantDomain, ApplicationName application) {
+ try {
+ return athenz.hasApplicationAccess(identity,
+ ApplicationAction.deploy,
+ tenantDomain,
+ application);
+ } catch (ZmsClientException e) {
+ throw new RuntimeException("Failed to authorize operation: (" + e.getMessage() + ")", e);
+ }
+ }
+
+}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java
deleted file mode 100644
index a1dfdbeb245..00000000000
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java
+++ /dev/null
@@ -1,118 +0,0 @@
-package com.yahoo.vespa.hosted.controller.restapi.filter;
-
-import com.google.inject.Inject;
-import com.yahoo.config.provision.ApplicationName;
-import com.yahoo.config.provision.SystemName;
-import com.yahoo.config.provision.TenantName;
-import com.yahoo.restapi.Path;
-import com.yahoo.vespa.athenz.api.AthenzDomain;
-import com.yahoo.vespa.athenz.api.AthenzIdentity;
-import com.yahoo.vespa.athenz.api.AthenzPrincipal;
-import com.yahoo.vespa.athenz.api.AthenzUser;
-import com.yahoo.vespa.athenz.client.zms.ZmsClientException;
-import com.yahoo.vespa.hosted.controller.Controller;
-import com.yahoo.vespa.hosted.controller.TenantController;
-import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction;
-import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade;
-import com.yahoo.vespa.hosted.controller.role.Role;
-import com.yahoo.vespa.hosted.controller.role.RoleMembership;
-import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant;
-import com.yahoo.vespa.hosted.controller.tenant.Tenant;
-import com.yahoo.vespa.hosted.controller.tenant.UserTenant;
-
-import javax.ws.rs.InternalServerErrorException;
-import java.security.Principal;
-import java.util.Optional;
-
-import static com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities.SCREWDRIVER_DOMAIN;
-
-/**
- * Translates Athenz principals to role memberships for use in access control.
- *
- * @author tokle
- * @author mpolden
- */
-public class AthenzRoleResolver implements RoleMembership.Resolver {
-
- private final AthenzFacade athenz;
- private final TenantController tenants;
- private final SystemName system;
-
- @Inject
- public AthenzRoleResolver(AthenzFacade athenz, Controller controller) {
- this.athenz = athenz;
- this.tenants = controller.tenants();
- this.system = controller.system();
- }
-
- private boolean isTenantAdmin(AthenzIdentity identity, Tenant tenant) {
- if (tenant instanceof AthenzTenant) {
- return athenz.hasTenantAdminAccess(identity, ((AthenzTenant) tenant).domain());
- } else if (tenant instanceof UserTenant) {
- if (!(identity instanceof AthenzUser)) {
- return false;
- }
- AthenzUser user = (AthenzUser) identity;
- return ((UserTenant) tenant).is(user.getName()) || isHostedOperator(identity);
- }
- throw new InternalServerErrorException("Unknown tenant type: " + tenant.getClass().getSimpleName());
- }
-
- private boolean hasDeployerAccess(AthenzIdentity identity, AthenzDomain tenantDomain, ApplicationName application) {
- try {
- return athenz.hasApplicationAccess(identity,
- ApplicationAction.deploy,
- tenantDomain,
- application);
- } catch (ZmsClientException e) {
- throw new InternalServerErrorException("Failed to authorize operation: (" + e.getMessage() + ")", e);
- }
- }
-
- private boolean isHostedOperator(AthenzIdentity identity) {
- return athenz.hasHostedOperatorAccess(identity);
- }
-
- @Override
- public RoleMembership membership(Principal principal, Optional<String> uriPath) {
- if ( ! (principal instanceof AthenzPrincipal))
- throw new IllegalStateException("Expected an AthenzPrincipal to be set on the request.");
-
- @SuppressWarnings("deprecation") // TODO: Use URI when refactoring this.
- Path path = new Path(uriPath.orElseThrow(() -> new IllegalArgumentException("This resolver needs the request path.")));
-
- path.matches("/application/v4/tenant/{tenant}/{*}");
- Optional<Tenant> tenant = Optional.ofNullable(path.get("tenant")).map(TenantName::from).flatMap(tenants::get);
-
- path.matches("/application/v4/tenant/{tenant}/application/{application}/{*}");
- Optional<ApplicationName> application = Optional.ofNullable(path.get("application")).map(ApplicationName::from);
-
- AthenzIdentity identity = ((AthenzPrincipal) principal).getIdentity();
-
- RoleMembership.Builder memberships = RoleMembership.in(system);
- if (isHostedOperator(identity)) {
- memberships.add(Role.hostedOperator);
- }
- if (tenant.isPresent() && isTenantAdmin(identity, tenant.get())) {
- memberships.add(Role.athenzTenantAdmin).limitedTo(tenant.get().name());
- }
- AthenzDomain principalDomain = identity.getDomain();
- if (principalDomain.equals(SCREWDRIVER_DOMAIN)) {
- if (application.isPresent() && tenant.isPresent()) {
- // NOTE: Only fine-grained deploy authorization for Athenz tenants
- if (tenant.get() instanceof AthenzTenant) {
- AthenzDomain tenantDomain = ((AthenzTenant) tenant.get()).domain();
- if (hasDeployerAccess(identity, tenantDomain, application.get())) {
- memberships.add(Role.tenantPipeline).limitedTo(tenant.get().name(), application.get());
- }
- }
- else {
- memberships.add(Role.tenantPipeline).limitedTo(tenant.get().name(), application.get());
- }
- }
- }
- memberships.add(Role.everyone);
- return memberships.build();
- }
-
-}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
index dfcc5f732f8..39736d709d0 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.controller.restapi.filter;
import com.google.inject.Inject;
+import com.yahoo.config.provision.SystemName;
import com.yahoo.jdisc.Response;
import com.yahoo.jdisc.http.HttpRequest;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
@@ -9,12 +10,11 @@ import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig;
import com.yahoo.jdisc.http.filter.security.cors.CorsRequestFilterBase;
import com.yahoo.log.LogLevel;
import com.yahoo.vespa.hosted.controller.Controller;
-import com.yahoo.vespa.hosted.controller.role.Action;
-import com.yahoo.vespa.hosted.controller.role.RoleMembership;
-import com.yahoo.yolean.chain.After;
-import com.yahoo.yolean.chain.Provides;
+import com.yahoo.vespa.hosted.controller.api.role.Action;
+import com.yahoo.vespa.hosted.controller.api.role.Role;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.api.role.SecurityContext;
-import javax.ws.rs.WebApplicationException;
import java.security.Principal;
import java.util.Optional;
import java.util.Set;
@@ -25,45 +25,41 @@ import java.util.logging.Logger;
*
* @author bjorncs
*/
-@After("com.yahoo.vespa.hosted.controller.athenz.filter.UserAuthWithAthenzPrincipalFilter")
-@Provides("ControllerAuthorizationFilter")
public class ControllerAuthorizationFilter extends CorsRequestFilterBase {
private static final Logger log = Logger.getLogger(ControllerAuthorizationFilter.class.getName());
- private final RoleMembership.Resolver roleResolver;
- private final Controller controller;
+ private final Roles roles;
@Inject
- public ControllerAuthorizationFilter(RoleMembership.Resolver roleResolver,
- Controller controller,
+ public ControllerAuthorizationFilter(Controller controller,
CorsFilterConfig corsConfig) {
- this(roleResolver, controller, Set.copyOf(corsConfig.allowedUrls()));
+ this(controller.system(), Set.copyOf(corsConfig.allowedUrls()));
}
- ControllerAuthorizationFilter(RoleMembership.Resolver roleResolver,
- Controller controller,
+ ControllerAuthorizationFilter(SystemName system,
Set<String> allowedUrls) {
super(allowedUrls);
- this.roleResolver = roleResolver;
- this.controller = controller;
+ this.roles = new Roles(system);
}
@Override
public Optional<ErrorResponse> filterRequest(DiscFilterRequest request) {
try {
Principal principal = request.getUserPrincipal();
- if (principal == null)
+ Optional<SecurityContext> securityContext = Optional.ofNullable((SecurityContext)request.getAttribute(SecurityContext.ATTRIBUTE_NAME));
+
+ if (securityContext.isEmpty())
return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Access denied"));
Action action = Action.from(HttpRequest.Method.valueOf(request.getMethod()));
- // Avoid expensive lookups when request is always legal.
- if (RoleMembership.everyoneIn(controller.system()).allows(action, request.getUri()))
+ // Avoid expensive look-ups when request is always legal.
+ if (roles.everyone().allows(action, request.getUri()))
return Optional.empty();
- RoleMembership roles = this.roleResolver.membership(principal, Optional.of(request.getRequestURI()));
- if (roles.allows(action, request.getUri()))
+ Set<Role> roles = securityContext.get().roles();
+ if (roles.stream().anyMatch(role -> role.allows(action, request.getUri())))
return Optional.empty();
}
catch (Exception e) {
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java
index 18b124778d5..067e6095b4d 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java
@@ -85,4 +85,6 @@ public class UserApiHandler extends LoggingRequestHandler {
return response;
}
+
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java
deleted file mode 100644
index 09e66528913..00000000000
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java
+++ /dev/null
@@ -1,122 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
-
-import com.yahoo.config.provision.ApplicationName;
-import com.yahoo.config.provision.SystemName;
-import com.yahoo.config.provision.TenantName;
-
-import java.net.URI;
-import java.security.Principal;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * A list of roles and their associated contexts. This defines the role membership of a tenant, and in which contexts
- * (see {@link Context}) those roles apply.
- *
- * @author mpolden
- * @author jonmv
- */
-public class RoleMembership {
-
- private final Map<Role, Set<Context>> roles;
-
- private RoleMembership(Map<Role, Set<Context>> roles) {
- this.roles = roles.entrySet().stream()
- .collect(Collectors.toUnmodifiableMap(entry -> entry.getKey(),
- entry -> Set.copyOf(entry.getValue())));
- }
-
- public static RoleMembership everyoneIn(SystemName system) {
- return in(system).add(Role.everyone).build();
- }
-
- public static Builder in(SystemName system) { return new BuilderWithRole(system); }
-
- /** Returns whether any role in this allows action to take place in path */
- public boolean allows(Action action, URI uri) {
- return roles.entrySet().stream().anyMatch(kv -> {
- Role role = kv.getKey();
- Set<Context> contexts = kv.getValue();
- return contexts.stream().anyMatch(context -> role.allows(action, uri, context));
- });
- }
-
- /** Returns the set of contexts for which the given role is valid. */
- public Set<Context> contextsFor(Role role) {
- return roles.getOrDefault(role, Collections.emptySet());
- }
-
- @Override
- public String toString() {
- return "roles " + roles;
- }
-
- /**
- * A role resolver. Identity providers can implement this to translate their internal representation of role
- * membership to a {@link RoleMembership}.
- */
- public interface Resolver {
- RoleMembership membership(Principal user, Optional<String> path); // TODO get rid of path.
- }
-
- public interface Builder {
-
- BuilderWithRole add(Role role);
-
- RoleMembership build();
-
- }
-
- public static class BuilderWithRole implements Builder {
-
- private final SystemName system;
- private final Map<Role, Set<Context>> roles;
-
- private Role current;
-
- private BuilderWithRole(SystemName system) {
- this.system = Objects.requireNonNull(system);
- this.roles = new HashMap<>();
- }
-
- @Override
- public BuilderWithRole add(Role role) {
- consumeCurrent(Context.unlimitedIn(system));
- current = role;
- return this;
- }
-
- public Builder limitedTo(TenantName tenant) {
- consumeCurrent(Context.limitedTo(tenant, system));
- return this;
- }
-
- public Builder limitedTo(TenantName tenant, ApplicationName application) {
- consumeCurrent(Context.limitedTo(tenant, application, system));
- return this;
- }
-
- @Override
- public RoleMembership build() {
- consumeCurrent(Context.unlimitedIn(system));
- return new RoleMembership(roles);
- }
-
- private void consumeCurrent(Context context) {
- if (current != null) {
- roles.putIfAbsent(current, new HashSet<>());
- roles.get(current).add(context);
- }
- current = null;
- }
-
- }
-
-}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java
index a22e5259919..b3953c47c01 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java
@@ -63,7 +63,7 @@ public class RotationRepository {
if (application.rotation().isPresent()) {
return allRotations.get(application.rotation().get());
}
- if (!application.deploymentSpec().globalServiceId().isPresent()) {
+ if (application.deploymentSpec().globalServiceId().isEmpty()) {
throw new IllegalArgumentException("global-service-id is not set in deployment spec");
}
long productionZones = application.deploymentSpec().zones().stream()
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
index 67d7a02a915..d1806fb5747 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
@@ -2,10 +2,17 @@ package com.yahoo.vespa.hosted.controller.security;
import com.google.inject.Inject;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.TenantName;
import com.yahoo.vespa.hosted.controller.Application;
import com.yahoo.vespa.hosted.controller.api.integration.organization.BillingInfo;
import com.yahoo.vespa.hosted.controller.api.integration.organization.Marketplace;
+import com.yahoo.vespa.hosted.controller.api.integration.user.RoleId;
+import com.yahoo.vespa.hosted.controller.api.integration.user.UserId;
+import com.yahoo.vespa.hosted.controller.api.integration.user.UserManagement;
+import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.api.role.TenantRole;
import com.yahoo.vespa.hosted.controller.tenant.CloudTenant;
import com.yahoo.vespa.hosted.controller.tenant.Tenant;
@@ -19,21 +26,28 @@ import java.util.List;
public class CloudAccessControl implements AccessControl {
private final Marketplace marketplace;
+ private final UserManagement userManagement;
+ private final Roles roles;
@Inject
- public CloudAccessControl(Marketplace marketplace) {
+ public CloudAccessControl(Marketplace marketplace, UserManagement userManagement, Roles roles) {
this.marketplace = marketplace;
+ this.userManagement = userManagement;
+ this.roles = roles;
}
@Override
public CloudTenant createTenant(TenantSpec tenantSpec, Credentials credentials, List<Tenant> existing) {
CloudTenantSpec spec = (CloudTenantSpec) tenantSpec;
+ CloudTenant tenant = new CloudTenant(spec.tenant(), new BillingInfo("customer", "Vespa"));
+ // CloudTenant tenant new CloudTenant(spec.tenant(), marketplace.resolveCustomer(spec.getRegistrationToken()));
+ // TODO Enable the above when things work.
- // Do things ...
+ RoleId ownerRole = RoleId.fromRole(roles.tenantOwner(spec.tenant()));
+ userManagement.createRole(ownerRole);
+ userManagement.addUsers(ownerRole, List.of(new UserId(credentials.user().getName())));
- // return new CloudTenant(spec.tenant(), marketplace.resolveCustomer(spec.getRegistrationToken()));
- // TODO Enable the above when things work.
- return new CloudTenant(spec.tenant(), new BillingInfo("customer", "Vespa"));
+ return tenant;
}
@Override
@@ -43,31 +57,48 @@ public class CloudAccessControl implements AccessControl {
@Override
public void deleteTenant(TenantName tenant, Credentials credentials) {
-
// Probably terminate customer subscription?
- // Delete tenant group
-
+ tenantRoles(tenant).stream()
+ .map(RoleId::fromRole)
+ .filter(userManagement.listRoles()::contains)
+ .forEach(userManagement::deleteRole);
}
@Override
public void createApplication(ApplicationId application, Credentials credentials) {
-
- // Create application group?
-
+ RoleId ownerRole = RoleId.fromRole(roles.applicationOwner(application.tenant(), application.application()));
+ userManagement.createRole(ownerRole);
+ userManagement.addUsers(ownerRole, List.of(new UserId(credentials.user().getName())));
}
@Override
public void deleteApplication(ApplicationId id, Credentials credentials) {
-
- // Delete application group?
-
+ applicationRoles(id.tenant(), id.application()).stream()
+ .map(RoleId::fromRole)
+ .filter(userManagement.listRoles()::contains)
+ .forEach(userManagement::deleteRole);
}
@Override
public List<Tenant> accessibleTenants(List<Tenant> tenants, Credentials credentials) {
- // Get credential things (token with roles or something) and check what it's good for.
+ // TODO: Get credential things (token with roles or something) and check what it's good for.
+ // TODO ... or ignore this here, and compute it somewhere else.
return Collections.emptyList();
}
+ private List<TenantRole> tenantRoles(TenantName tenant) {
+ return List.of(roles.tenantOperator(tenant),
+ roles.tenantAdmin(tenant),
+ roles.tenantOwner(tenant));
+ }
+
+ private List<ApplicationRole> applicationRoles(TenantName tenant, ApplicationName application) {
+ return List.of(roles.applicationReader(tenant, application),
+ roles.applicationDeveloper(tenant, application),
+ roles.applicationOperator(tenant, application),
+ roles.applicationAdmin(tenant, application),
+ roles.applicationOwner(tenant, application));
+ }
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java
index 631d4debe88..ea931616211 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java
@@ -20,7 +20,7 @@ public class CloudAccessControlRequests implements AccessControlRequests {
@Override
public Credentials credentials(TenantName tenant, Inspector requestObject, HttpRequest request) {
- // TODO Pick out JWT data and return a specialised credentials thing.
+ // TODO Include roles, if this is to be used for displaying accessible data.
return new Credentials(request.getUserPrincipal());
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java
index 2f7dd656678..358088e9b08 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java
@@ -1,6 +1,7 @@
package com.yahoo.vespa.hosted.controller.security;
import com.yahoo.config.provision.TenantName;
+import com.yahoo.vespa.hosted.controller.tenant.Tenant;
import static java.util.Objects.requireNonNull;
@@ -14,7 +15,7 @@ public abstract class TenantSpec {
private final TenantName tenant;
protected TenantSpec(TenantName tenant) {
- this.tenant = requireNonNull(tenant);
+ this.tenant = Tenant.requireName(requireNonNull(tenant));
}
/** The name of the tenant. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java
index 19b7229515b..e0c750dec80 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java
@@ -49,7 +49,7 @@ public abstract class Tenant {
return Objects.hash(name);
}
- static TenantName requireName(TenantName name) {
+ public static TenantName requireName(TenantName name) {
if ( ! name.value().matches("^(?=.{1,20}$)[a-z](-?[a-z0-9]+)*$")) {
throw new IllegalArgumentException("New tenant or application names must start with a letter, may " +
"contain no more than 20 characters, and may only contain lowercase " +
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
index 1f00d99350a..bc42b672da4 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
@@ -12,13 +12,14 @@ import com.yahoo.config.provision.InstanceName;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.TenantName;
+import com.yahoo.vespa.flags.Flags;
+import com.yahoo.vespa.flags.InMemoryFlagSource;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus;
import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.SourceRevision;
import com.yahoo.vespa.hosted.controller.api.integration.dns.Record;
-import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName;
import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingEndpoint;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import com.yahoo.vespa.hosted.controller.application.ApplicationPackage;
@@ -281,40 +282,60 @@ public class ControllerTest {
.region("us-central-1") // Two deployments should result in each DNS alias being registered once
.build();
- Function<String, Optional<Record>> findCname = (name) -> tester.controllerTester().nameService()
- .findRecords(Record.Type.CNAME,
- RecordName.from(name))
- .stream()
- .findFirst();
-
tester.deployCompletely(application, applicationPackage);
assertEquals(3, tester.controllerTester().nameService().records().size());
- Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com");
+ Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
- record = findCname.apply("app1--tenant1.global.vespa.oath.cloud");
+ record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud");
assertTrue(record.isPresent());
assertEquals("app1--tenant1.global.vespa.oath.cloud", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
- record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
}
@Test
- public void testUpdatesExistingDnsAlias() {
+ public void testRedirectLegacyDnsNames() { // TODO: Remove together with Flags.REDIRECT_LEGACY_DNS_NAMES
DeploymentTester tester = new DeploymentTester();
+ Application application = tester.createApplication("app1", "tenant1", 1, 1L);
+ ApplicationPackage applicationPackage = new ApplicationPackageBuilder()
+ .environment(Environment.prod)
+ .globalServiceId("foo")
+ .region("us-west-1")
+ .region("us-central-1")
+ .build();
+
+ ((InMemoryFlagSource) tester.controller().flagSource()).withBooleanFlag(Flags.REDIRECT_LEGACY_DNS_NAMES.id(), true);
+
+ tester.deployCompletely(application, applicationPackage);
+ assertEquals(3, tester.controllerTester().nameService().records().size());
+
+ Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com");
+ assertTrue(record.isPresent());
+ assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString());
+ assertEquals("app1--tenant1.global.vespa.oath.cloud.", record.get().data().asString());
- Function<String, Optional<Record>> findCname = (name) -> tester.controllerTester().nameService()
- .findRecords(Record.Type.CNAME,
- RecordName.from(name))
- .stream()
- .findFirst();
+ record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud");
+ assertTrue(record.isPresent());
+ assertEquals("app1--tenant1.global.vespa.oath.cloud", record.get().name().asString());
+ assertEquals("rotation-fqdn-01.", record.get().data().asString());
+
+ record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com");
+ assertTrue(record.isPresent());
+ assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString());
+ assertEquals("app1--tenant1.global.vespa.oath.cloud.", record.get().data().asString());
+ }
+
+ @Test
+ public void testUpdatesExistingDnsAlias() {
+ DeploymentTester tester = new DeploymentTester();
// Application 1 is deployed and deleted
{
@@ -329,12 +350,12 @@ public class ControllerTest {
tester.deployCompletely(app1, applicationPackage);
assertEquals(3, tester.controllerTester().nameService().records().size());
- Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com");
+ Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
- record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
@@ -356,13 +377,13 @@ public class ControllerTest {
}
// Records remain
- record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
- record = findCname.apply("app1--tenant1.global.vespa.oath.cloud");
+ record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud");
assertTrue(record.isPresent());
- record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
}
@@ -378,17 +399,17 @@ public class ControllerTest {
tester.deployCompletely(app2, applicationPackage);
assertEquals(6, tester.controllerTester().nameService().records().size());
- Optional<Record> record = findCname.apply("app2--tenant2.global.vespa.yahooapis.com");
+ Optional<Record> record = tester.controllerTester().findCname("app2--tenant2.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app2--tenant2.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
- record = findCname.apply("app2--tenant2.global.vespa.oath.cloud");
+ record = tester.controllerTester().findCname("app2--tenant2.global.vespa.oath.cloud");
assertTrue(record.isPresent());
assertEquals("app2--tenant2.global.vespa.oath.cloud", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
- record = findCname.apply("app2.tenant2.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app2.tenant2.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("app2.tenant2.global.vespa.yahooapis.com", record.get().name().asString());
assertEquals("rotation-fqdn-01.", record.get().data().asString());
@@ -411,15 +432,15 @@ public class ControllerTest {
// Existing DNS records are updated to point to the newly assigned rotation
assertEquals(6, tester.controllerTester().nameService().records().size());
- Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com");
+ Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("rotation-fqdn-02.", record.get().data().asString());
- record = findCname.apply("app1--tenant1.global.vespa.oath.cloud");
+ record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud");
assertTrue(record.isPresent());
assertEquals("rotation-fqdn-02.", record.get().data().asString());
- record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com");
+ record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com");
assertTrue(record.isPresent());
assertEquals("rotation-fqdn-02.", record.get().data().asString());
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
index d7845e4bfa1..c18e9c46f07 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
@@ -13,6 +13,7 @@ import com.yahoo.vespa.athenz.api.AthenzUser;
import com.yahoo.vespa.athenz.api.OktaAccessToken;
import com.yahoo.vespa.curator.Lock;
import com.yahoo.vespa.curator.mock.MockCurator;
+import com.yahoo.vespa.flags.InMemoryFlagSource;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions;
import com.yahoo.vespa.hosted.controller.api.identifiers.Property;
import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId;
@@ -22,14 +23,15 @@ import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationS
import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType;
import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService;
-import com.yahoo.vespa.hosted.controller.api.integration.entity.MemoryEntityService;
+import com.yahoo.vespa.hosted.controller.api.integration.dns.Record;
+import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName;
import com.yahoo.vespa.hosted.controller.api.integration.github.GitHubMock;
import com.yahoo.vespa.hosted.controller.api.integration.organization.Contact;
import com.yahoo.vespa.hosted.controller.api.integration.organization.MockContactRetriever;
import com.yahoo.vespa.hosted.controller.api.integration.organization.MockIssueHandler;
-import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer;
import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingGenerator;
import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockBuildService;
+import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer;
import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockRunDataStore;
import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockTesterCloud;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
@@ -43,11 +45,11 @@ import com.yahoo.vespa.hosted.controller.integration.ConfigServerMock;
import com.yahoo.vespa.hosted.controller.integration.MetricsServiceMock;
import com.yahoo.vespa.hosted.controller.integration.RoutingGeneratorMock;
import com.yahoo.vespa.hosted.controller.integration.ZoneRegistryMock;
-import com.yahoo.vespa.hosted.controller.security.AthenzCredentials;
-import com.yahoo.vespa.hosted.controller.security.AthenzTenantSpec;
import com.yahoo.vespa.hosted.controller.persistence.ApplicationSerializer;
import com.yahoo.vespa.hosted.controller.persistence.CuratorDb;
import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb;
+import com.yahoo.vespa.hosted.controller.security.AthenzCredentials;
+import com.yahoo.vespa.hosted.controller.security.AthenzTenantSpec;
import com.yahoo.vespa.hosted.controller.security.Credentials;
import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant;
import com.yahoo.vespa.hosted.controller.tenant.Tenant;
@@ -188,6 +190,10 @@ public final class ControllerTester {
return contactRetriever;
}
+ public Optional<Record> findCname(String name) {
+ return nameService.findRecords(Record.Type.CNAME, RecordName.from(name)).stream().findFirst();
+ }
+
/** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */
public final void createNewController() {
controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb,
@@ -345,7 +351,8 @@ public final class ControllerTester {
buildService,
new MockRunDataStore(),
() -> "test-controller",
- new MockMailer());
+ new MockMailer(),
+ new InMemoryFlagSource());
// Calculate initial versions
controller.updateVersionStatus(VersionStatus.compute(controller));
return controller;
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java
index 23c7ec537f5..2b8e4f52d23 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java
@@ -4,7 +4,6 @@ package com.yahoo.vespa.hosted.controller.maintenance;
import com.yahoo.config.application.api.ValidationId;
import com.yahoo.config.provision.Environment;
import com.yahoo.config.provision.RegionName;
-import com.yahoo.vespa.athenz.api.OktaAccessToken;
import com.yahoo.vespa.hosted.controller.Application;
import com.yahoo.vespa.hosted.controller.ControllerTester;
import com.yahoo.vespa.hosted.controller.api.integration.dns.Record;
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
index 331a6ba9ac8..c21d4b4b0bf 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
@@ -63,6 +63,7 @@ public class ControllerContainerTest {
" <item>http://localhost</item>\n" +
" </allowedUrls>\n" +
" </config>\n" +
+ " <component id='com.yahoo.vespa.flags.InMemoryFlagSource'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.api.integration.chef.ChefMock'/>\n" +
@@ -94,9 +95,9 @@ public class ControllerContainerTest {
" <component id='com.yahoo.vespa.hosted.controller.integration.ApplicationStoreMock'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.api.integration.stubs.MockTesterCloud'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer'/>\n" +
+ " <component id='com.yahoo.vespa.hosted.controller.api.role.Roles'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.security.AthenzAccessControlRequests'/>\n" +
" <component id='com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade'/>\n" +
- " <component id='com.yahoo.vespa.hosted.controller.restapi.filter.AthenzRoleResolver'/>\n" +
" <handler id='com.yahoo.vespa.hosted.controller.restapi.application.ApplicationApiHandler'>\n" +
" <binding>http://*/application/v4/*</binding>\n" +
" </handler>\n" +
@@ -134,6 +135,7 @@ public class ControllerContainerTest {
" <filtering>\n" +
" <request-chain id='default'>\n" +
" <filter id='com.yahoo.vespa.hosted.controller.integration.AthenzFilterMock'/>\n" +
+ " <filter id='com.yahoo.vespa.hosted.controller.restapi.filter.AthenzRoleFilter'/>\n" +
" <filter id='com.yahoo.vespa.hosted.controller.restapi.filter.ControllerAuthorizationFilter'/>\n" +
" <binding>http://*/*</binding>\n" +
" </request-chain>\n" +
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
index 40d39248cb5..bde1c037bf2 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
@@ -384,7 +384,7 @@ public class ApplicationApiTest extends ControllerContainerTest {
tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/prod/region/us-central-1/instance/default/logs?from=1233&to=3214", GET)
.userIdentity(USER_ID),
new File("logs.json"));
- tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/prod/region/us-central-1/instance/default/logs?from=1233&to=3214&streaming", GET)
+ tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/dev/region/us-central-1/instance/default/logs?from=1233&to=3214&streaming", GET)
.userIdentity(USER_ID),
"INFO - All good");
@@ -758,31 +758,6 @@ public class ApplicationApiTest extends ControllerContainerTest {
new File("deploy-no-deployment.json"), 400);
}
- // Tests deployment to config server when using just on API call
- // For now this depends on a switch in ApplicationController that does this for by- tenants in CD only
- @Test
- public void testDeployDirectlyUsingOneCallForDeploy() {
- // Setup
- tester.computeVersionStatus();
- UserId userId = new UserId("new_user");
- createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, userId);
-
- // Create tenant
- // PUT (create) the authenticated user
- byte[] data = new byte[0];
- tester.assertResponse(request("/application/v4/user?user=new_user&domain=by", PUT)
- .data(data)
- .userIdentity(userId), // Normalized to by-new-user by API
- new File("create-user-response.json"));
-
- // POST (deploy) an application to a dev zone
- HttpEntity entity = createApplicationDeployData(applicationPackage, true);
- tester.assertResponse(request("/application/v4/tenant/by-new-user/application/application1/environment/dev/region/cd-us-central-1/instance/default", POST)
- .data(entity)
- .userIdentity(userId),
- new File("deploy-result.json"));
- }
-
@Test
public void testSortsDeploymentsAndJobs() {
tester.computeVersionStatus();
@@ -897,7 +872,7 @@ public class ApplicationApiTest extends ControllerContainerTest {
"{\"error-code\":\"BAD_REQUEST\",\"message\":\"Tenant 'tenant1' already exists\"}",
400);
- // POST (add) a Athenz tenant with underscore in name
+ // POST (add) an Athenz tenant with underscore in name
tester.assertResponse(request("/application/v4/tenant/my_tenant_2", POST)
.userIdentity(USER_ID)
.data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}")
@@ -905,7 +880,7 @@ public class ApplicationApiTest extends ControllerContainerTest {
"{\"error-code\":\"BAD_REQUEST\",\"message\":\"New tenant or application names must start with a letter, may contain no more than 20 characters, and may only contain lowercase letters, digits or dashes, but no double-dashes.\"}",
400);
- // POST (add) a Athenz tenant with by- prefix
+ // POST (add) an Athenz tenant with by- prefix
tester.assertResponse(request("/application/v4/tenant/by-tenant2", POST)
.userIdentity(USER_ID)
.data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}")
@@ -913,6 +888,14 @@ public class ApplicationApiTest extends ControllerContainerTest {
"{\"error-code\":\"BAD_REQUEST\",\"message\":\"Athenz tenant name cannot have prefix 'by-'\"}",
400);
+ // POST (add) an Athenz tenant with a reserved name
+ tester.assertResponse(request("/application/v4/tenant/hosted-vespa", POST)
+ .userIdentity(USER_ID)
+ .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}")
+ .oktaAccessToken(OKTA_AT),
+ "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Tenant 'hosted-vespa' already exists\"}",
+ 400);
+
// POST (create) an (empty) application
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1", POST)
.userIdentity(USER_ID)
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json
index c31a47cb5b2..cd531bb96da 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant1",
"application": "application1",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json
index a4026d6a812..ff22b95739d 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant2",
"application": "application2",
"instance": "default",
"url": "http://localhost:8080/application/v4/tenant/tenant2/application/application2"
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json
index 1ec229a2b4a..1d56944f6bc 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant1",
"application":"application1",
"instance":"default",
"url":"http://localhost:8080/application/v4/tenant/tenant1/application/application1"
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
index 9f65b5952e1..f2f38f7f509 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant1",
"application": "application1",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/",
@@ -237,21 +238,21 @@
"rotationId": "rotation-id-1",
"instances": [
{
- "environment": "prod",
- "region": "us-west-1",
- "instance": "default",
"bcpStatus": {
"rotationStatus": "IN"
},
+ "environment": "prod",
+ "region": "us-west-1",
+ "instance": "default",
"url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-west-1/instance/default"
},
{
- "environment": "prod",
- "region": "us-east-3",
- "instance": "default",
"bcpStatus": {
"rotationStatus": "UNKNOWN"
},
+ "environment": "prod",
+ "region": "us-east-3",
+ "instance": "default",
"url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-east-3/instance/default"
}
],
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
index 3744e44152a..22e8573b1d4 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant1",
"application": "application1",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/",
@@ -231,12 +232,12 @@
"url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/dev/region/us-west-1/instance/default"
},
{
- "environment": "prod",
- "region": "us-central-1",
- "instance": "default",
"bcpStatus": {
"rotationStatus": "UNKNOWN"
},
+ "environment": "prod",
+ "region": "us-central-1",
+ "instance": "default",
"url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-central-1/instance/default"
}
],
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
index 822bc447d8d..662e045d169 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant1",
"application": "application1",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json
index 55803074ade..1477e18b4b8 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant2",
"application": "application2",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json
index 2c34e5ae712..3063bb62b7e 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json
@@ -1,4 +1,5 @@
{
+ "tenant": "tenant2",
"application": "application2",
"instance": "default",
"deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
index ac1797986fc..af21260676c 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
@@ -1,4 +1,9 @@
{
+ "tenant": "tenant1",
+ "application": "application1",
+ "instance": "default",
+ "environment": "prod",
+ "region": "us-central-1",
"serviceUrls": [
"http://old-endpoint.vespa.yahooapis.com:4080",
"http://qrs-endpoint.vespa.yahooapis.com:4080",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
index 65ea3925d8c..54e94c4521e 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
@@ -1,7 +1,9 @@
{
+ "tenant": "tenant1",
+ "application": "application1",
+ "instance": "default",
"environment": "dev",
"region": "us-west-1",
- "instance": "default",
"serviceUrls": [
"http://old-endpoint.vespa.yahooapis.com:4080",
"http://qrs-endpoint.vespa.yahooapis.com:4080",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json
index a3380d823f3..cfefe629b9a 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json
@@ -1,10 +1,12 @@
{
- "environment": "prod",
- "region": "us-central-1",
- "instance": "default",
"bcpStatus": {
"rotationStatus": "UNKNOWN"
},
+ "tenant": "tenant1",
+ "application": "application1",
+ "instance": "default",
+ "environment": "prod",
+ "region": "us-central-1",
"serviceUrls": [
"http://old-endpoint.vespa.yahooapis.com:4080",
"http://qrs-endpoint.vespa.yahooapis.com:4080",
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json
index ad8e65692b4..b222c33291c 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json
@@ -5,6 +5,7 @@
"property": "property1",
"applications": [
{
+ "tenant": "tenant1",
"application":"application1",
"instance":"default",
"url":"http://localhost:8080/application/v4/tenant/tenant1/application/application1"
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java
new file mode 100644
index 00000000000..dc4235e52bf
--- /dev/null
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java
@@ -0,0 +1,122 @@
+package com.yahoo.vespa.hosted.controller.restapi.filter;
+
+import com.yahoo.config.provision.ApplicationName;
+import com.yahoo.config.provision.TenantName;
+import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig;
+import com.yahoo.vespa.athenz.api.AthenzDomain;
+import com.yahoo.vespa.athenz.api.AthenzPrincipal;
+import com.yahoo.vespa.athenz.api.AthenzService;
+import com.yahoo.vespa.athenz.api.AthenzUser;
+import com.yahoo.vespa.hosted.controller.ControllerTester;
+import com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId;
+import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction;
+import com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities;
+import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade;
+import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock;
+import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.net.URI;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author jonmv
+ */
+public class AthenzRoleFilterTest {
+
+ private static final AthenzPrincipal USER = new AthenzPrincipal(new AthenzUser("john"));
+ private static final AthenzPrincipal HOSTED_OPERATOR = new AthenzPrincipal(new AthenzUser("hosted-operator"));
+ private static final AthenzDomain TENANT_DOMAIN = new AthenzDomain("tenantdomain");
+ private static final AthenzDomain TENANT_DOMAIN2 = new AthenzDomain("tenantdomain2");
+ private static final AthenzPrincipal TENANT_ADMIN = new AthenzPrincipal(new AthenzService(TENANT_DOMAIN, "adminservice"));
+ private static final AthenzPrincipal TENANT_PIPELINE = new AthenzPrincipal(HostedAthenzIdentities.from(new ScrewdriverId("12345")));
+ private static final TenantName TENANT = TenantName.from("mytenant");
+ private static final TenantName TENANT2 = TenantName.from("othertenant");
+ private static final ApplicationName APPLICATION = ApplicationName.from("myapp");
+ private static final URI NO_CONTEXT_PATH = URI.create("/application/v4/");
+ private static final URI TENANT_CONTEXT_PATH = URI.create("/application/v4/tenant/mytenant/");
+ private static final URI APPLICATION_CONTEXT_PATH = URI.create("/application/v4/tenant/mytenant/application/myapp/");
+ private static final URI TENANT2_CONTEXT_PATH = URI.create("/application/v4/tenant/othertenant/");
+ private static final URI APPLICATION2_CONTEXT_PATH = URI.create("/application/v4/tenant/othertenant/application/myapp/");
+
+ private ControllerTester tester;
+ private AthenzRoleFilter filter;
+
+ @Before
+ public void setup() {
+ tester = new ControllerTester();
+ filter = new AthenzRoleFilter(new CorsFilterConfig.Builder().build(),
+ new AthenzFacade(new AthenzClientFactoryMock(tester.athenzDb())),
+ tester.controller());
+
+ tester.athenzDb().hostedOperators.add(HOSTED_OPERATOR.getIdentity());
+ tester.createTenant(TENANT.value(), TENANT_DOMAIN.getName(), null);
+ tester.createApplication(TENANT, APPLICATION.value(), "default", 12345);
+ AthenzDbMock.Domain tenantDomain = tester.athenzDb().domains.get(TENANT_DOMAIN);
+ tenantDomain.admins.add(TENANT_ADMIN.getIdentity());
+ tenantDomain.applications.get(new ApplicationId(APPLICATION.value())).addRoleMember(ApplicationAction.deploy, TENANT_PIPELINE.getIdentity());
+ tester.createTenant(TENANT2.value(), TENANT_DOMAIN2.getName(), null);
+ tester.createApplication(TENANT2, APPLICATION.value(), "default", 42);
+ }
+
+ @Test
+ public void testTranslations() {
+
+ Roles roles = new Roles(tester.controller().system());
+
+ // Hosted operators are always members of the hostedOperator role.
+ assertEquals(Set.of(roles.hostedOperator()),
+ filter.roles(HOSTED_OPERATOR, NO_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.hostedOperator()),
+ filter.roles(HOSTED_OPERATOR, TENANT_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.hostedOperator()),
+ filter.roles(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH));
+
+ // Tenant admins are members of the athenzTenantAdmin role within their tenant subtree.
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_PIPELINE, NO_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.athenzTenantAdmin(TENANT)),
+ filter.roles(TENANT_ADMIN, TENANT_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.athenzTenantAdmin(TENANT)),
+ filter.roles(TENANT_ADMIN, APPLICATION_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_ADMIN, TENANT2_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_ADMIN, APPLICATION2_CONTEXT_PATH));
+
+ // Build services are members of the tenantPipeline role within their application subtree.
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_PIPELINE, NO_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_PIPELINE, TENANT_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.tenantPipeline(TENANT, APPLICATION)),
+ filter.roles(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(TENANT_PIPELINE, APPLICATION2_CONTEXT_PATH));
+
+ // Unprivileged users are just members of the everyone role.
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(USER, NO_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(USER, TENANT_CONTEXT_PATH));
+
+ assertEquals(Set.of(roles.everyone()),
+ filter.roles(USER, APPLICATION_CONTEXT_PATH));
+ }
+
+}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java
deleted file mode 100644
index 4628b95ad3c..00000000000
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java
+++ /dev/null
@@ -1,120 +0,0 @@
-package com.yahoo.vespa.hosted.controller.restapi.filter;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.yahoo.config.provision.ApplicationName;
-import com.yahoo.config.provision.TenantName;
-import com.yahoo.vespa.athenz.api.AthenzDomain;
-import com.yahoo.vespa.athenz.api.AthenzPrincipal;
-import com.yahoo.vespa.athenz.api.AthenzService;
-import com.yahoo.vespa.athenz.api.AthenzUser;
-import com.yahoo.vespa.hosted.controller.ControllerTester;
-import com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId;
-import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId;
-import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction;
-import com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities;
-import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade;
-import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock;
-import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock;
-import com.yahoo.vespa.hosted.controller.role.Context;
-import com.yahoo.vespa.hosted.controller.role.Role;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.util.Optional;
-import java.util.Set;
-
-import static java.util.Collections.emptySet;
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author jonmv
- */
-public class AthenzRoleResolverTest {
-
- private static final ObjectMapper mapper = new ObjectMapper();
-
- private static final AthenzPrincipal USER = new AthenzPrincipal(new AthenzUser("john"));
- private static final AthenzPrincipal HOSTED_OPERATOR = new AthenzPrincipal(new AthenzUser("hosted-operator"));
- private static final AthenzDomain TENANT_DOMAIN = new AthenzDomain("tenantdomain");
- private static final AthenzDomain TENANT_DOMAIN2 = new AthenzDomain("tenantdomain2");
- private static final AthenzPrincipal TENANT_ADMIN = new AthenzPrincipal(new AthenzService(TENANT_DOMAIN, "adminservice"));
- private static final AthenzPrincipal TENANT_PIPELINE = new AthenzPrincipal(HostedAthenzIdentities.from(new ScrewdriverId("12345")));
- private static final TenantName TENANT = TenantName.from("mytenant");
- private static final TenantName TENANT2 = TenantName.from("othertenant");
- private static final ApplicationName APPLICATION = ApplicationName.from("myapp");
- private static final Optional<String> NO_CONTEXT_PATH = Optional.of("/application/v4/");
- private static final Optional<String> TENANT_CONTEXT_PATH = Optional.of("/application/v4/tenant/mytenant/");
- private static final Optional<String> APPLICATION_CONTEXT_PATH = Optional.of("/application/v4/tenant/mytenant/application/myapp/");
- private static final Optional<String> TENANT2_CONTEXT_PATH = Optional.of("/application/v4/tenant/othertenant/");
- private static final Optional<String> APPLICATION2_CONTEXT_PATH = Optional.of("/application/v4/tenant/othertenant/application/myapp/");
-
- private ControllerTester tester;
- private AthenzRoleResolver resolver;
-
- @Before
- public void setup() {
- tester = new ControllerTester();
- resolver = new AthenzRoleResolver(new AthenzFacade(new AthenzClientFactoryMock(tester.athenzDb())),
- tester.controller());
-
- tester.athenzDb().hostedOperators.add(HOSTED_OPERATOR.getIdentity());
- tester.createTenant(TENANT.value(), TENANT_DOMAIN.getName(), null);
- tester.createApplication(TENANT, APPLICATION.value(), "default", 12345);
- AthenzDbMock.Domain tenantDomain = tester.athenzDb().domains.get(TENANT_DOMAIN);
- tenantDomain.admins.add(TENANT_ADMIN.getIdentity());
- tenantDomain.applications.get(new ApplicationId(APPLICATION.value())).addRoleMember(ApplicationAction.deploy, TENANT_PIPELINE.getIdentity());
- tester.createTenant(TENANT2.value(), TENANT_DOMAIN2.getName(), null);
- tester.createApplication(TENANT2, APPLICATION.value(), "default", 42);
- }
-
- @Test
- public void testTranslations() {
-
- // Everyone is member of the everyone role.
- assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())),
- resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.everyone));
- assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())),
- resolver.membership(TENANT_ADMIN, TENANT_CONTEXT_PATH).contextsFor(Role.everyone));
- assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())),
- resolver.membership(TENANT_PIPELINE, NO_CONTEXT_PATH).contextsFor(Role.everyone));
- assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())),
- resolver.membership(USER, APPLICATION_CONTEXT_PATH).contextsFor(Role.everyone));
-
- // Only operators are members of the operator role.
- assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())),
- resolver.membership(HOSTED_OPERATOR, TENANT_CONTEXT_PATH).contextsFor(Role.hostedOperator));
- assertEquals(emptySet(),
- resolver.membership(TENANT_ADMIN, NO_CONTEXT_PATH).contextsFor(Role.hostedOperator));
- assertEquals(emptySet(),
- resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.hostedOperator));
- assertEquals(emptySet(),
- resolver.membership(USER, TENANT_CONTEXT_PATH).contextsFor(Role.hostedOperator));
-
- // Operators and tenant admins are tenant admins of their tenants.
- assertEquals(Set.of(Context.limitedTo(TENANT, tester.controller().system())),
- resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
- assertEquals(emptySet(), // TODO this is wrong, but we can't do better until we ask ZMS for roles.
- resolver.membership(TENANT_ADMIN, NO_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
- assertEquals(Set.of(Context.limitedTo(TENANT, tester.controller().system())),
- resolver.membership(TENANT_ADMIN, TENANT_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
- assertEquals(emptySet(),
- resolver.membership(TENANT_ADMIN, TENANT2_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
- assertEquals(emptySet(),
- resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
- assertEquals(emptySet(),
- resolver.membership(USER, TENANT_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin));
-
- // Only build services are pipeline operators of their applications.
- assertEquals(emptySet(),
- resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline));
- assertEquals(emptySet(),
- resolver.membership(TENANT_ADMIN, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline));
- assertEquals(Set.of(Context.limitedTo(TENANT, APPLICATION, tester.controller().system())),
- resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline));
- assertEquals(emptySet(),
- resolver.membership(TENANT_PIPELINE, APPLICATION2_CONTEXT_PATH).contextsFor(Role.tenantPipeline));
- assertEquals(emptySet(),
- resolver.membership(USER, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline));
- }
-
-}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java
index 39b08695986..105e10eefd2 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java
@@ -6,13 +6,10 @@ import com.yahoo.application.container.handler.Request;
import com.yahoo.config.provision.SystemName;
import com.yahoo.jdisc.http.HttpRequest.Method;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
-import com.yahoo.vespa.athenz.api.AthenzIdentity;
-import com.yahoo.vespa.athenz.api.AthenzPrincipal;
-import com.yahoo.vespa.athenz.api.AthenzUser;
import com.yahoo.vespa.hosted.controller.ControllerTester;
+import com.yahoo.vespa.hosted.controller.api.role.Roles;
+import com.yahoo.vespa.hosted.controller.api.role.SecurityContext;
import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper;
-import com.yahoo.vespa.hosted.controller.role.Role;
-import com.yahoo.vespa.hosted.controller.role.RoleMembership;
import org.junit.Test;
import java.io.IOException;
@@ -33,39 +30,42 @@ import static org.junit.Assert.assertTrue;
public class ControllerAuthorizationFilterTest {
private static final ObjectMapper mapper = new ObjectMapper();
- private static AthenzIdentity identity = new AthenzUser("user");
@Test
public void operator() {
ControllerTester tester = new ControllerTester();
- RoleMembership.Resolver operatorResolver = (user, path) -> RoleMembership.in(tester.controller().system())
- .add(Role.hostedOperator)
- .build();
- ControllerAuthorizationFilter filter = createFilter(tester, operatorResolver);
- assertIsAllowed(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity)));
- assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity)));
- assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity)));
+ Roles roles = new Roles(tester.controller().system());
+ SecurityContext securityContext = new SecurityContext(() -> "operator", Set.of(roles.hostedOperator()));
+ ControllerAuthorizationFilter filter = createFilter(tester);
+
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext)));
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext)));
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext)));
}
@Test
public void unprivileged() {
ControllerTester tester = new ControllerTester();
- RoleMembership.Resolver emptyResolver = (user, path) -> RoleMembership.in(tester.controller().system()).build();
- ControllerAuthorizationFilter filter = createFilter(tester, emptyResolver);
- assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity)));
- assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity)));
- assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity)));
+ Roles roles = new Roles(tester.controller().system());
+ SecurityContext securityContext = new SecurityContext(() -> "user", Set.of(roles.everyone()));
+ ControllerAuthorizationFilter filter = createFilter(tester);
+
+ assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext)));
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext)));
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext)));
}
@Test
public void unprivilegedInPublic() {
ControllerTester tester = new ControllerTester();
tester.zoneRegistry().setSystemName(SystemName.Public);
- RoleMembership.Resolver emptyResolver = (user, path) -> RoleMembership.in(tester.controller().system()).build();
- ControllerAuthorizationFilter filter = createFilter(tester, emptyResolver);
- assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity)));
- assertIsForbidden(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity)));
- assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity)));
+ Roles roles = new Roles(tester.controller().system());
+ SecurityContext securityContext = new SecurityContext(() -> "user", Set.of(roles.everyone()));
+
+ ControllerAuthorizationFilter filter = createFilter(tester);
+ assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext)));
+ assertIsForbidden(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext)));
+ assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext)));
}
private static void assertIsAllowed(Optional<AuthorizationResponse> response) {
@@ -79,8 +79,8 @@ public class ControllerAuthorizationFilterTest {
assertEquals("Invalid status code", FORBIDDEN, response.get().statusCode);
}
- private static ControllerAuthorizationFilter createFilter(ControllerTester tester, RoleMembership.Resolver resolver) {
- return new ControllerAuthorizationFilter(resolver, tester.controller(), Set.of("http://localhost"));
+ private static ControllerAuthorizationFilter createFilter(ControllerTester tester) {
+ return new ControllerAuthorizationFilter(tester.controller().system(), Set.of("http://localhost"));
}
private static Optional<AuthorizationResponse> invokeFilter(ControllerAuthorizationFilter filter,
@@ -91,9 +91,9 @@ public class ControllerAuthorizationFilterTest {
.map(response -> new AuthorizationResponse(response.getStatus(), getErrorMessage(responseHandlerMock)));
}
- private static DiscFilterRequest createRequest(Method method, String path, AthenzIdentity identity) {
- Request request = new Request(path, new byte[0], Request.Method.valueOf(method.name()),
- new AthenzPrincipal(identity));
+ private static DiscFilterRequest createRequest(Method method, String path, SecurityContext securityContext) {
+ Request request = new Request(path, new byte[0], Request.Method.valueOf(method.name()), securityContext.principal());
+ request.getAttributes().put(SecurityContext.ATTRIBUTE_NAME, securityContext);
return new ApplicationRequestToDiscFilterRequestWrapper(request);
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java
deleted file mode 100644
index 1da5d3764f6..00000000000
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.role;
-
-import com.yahoo.config.provision.ApplicationName;
-import com.yahoo.config.provision.SystemName;
-import com.yahoo.config.provision.TenantName;
-import org.junit.Test;
-
-import java.net.URI;
-
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-/**
- * @author mpolden
- */
-public class RoleMembershipTest {
-
- @Test
- public void operator_membership() {
- RoleMembership roles = RoleMembership.in(SystemName.main)
- .add(Role.hostedOperator)
- .build();
-
- // Operator actions
- assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined")));
- assertTrue(roles.allows(Action.create, URI.create("/controller/v1/foo")));
- assertTrue(roles.allows(Action.update, URI.create("/os/v1/bar")));
- assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
- assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2")));
- }
-
- @Test
- public void tenant_membership() {
- RoleMembership roles = RoleMembership.in(SystemName.main)
- .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1"))
- .build();
- assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined")));
- assertFalse("Deny access to operator API", roles.allows(Action.create, URI.create("/controller/v1/foo")));
- assertFalse("Deny access to other tenant and app", roles.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2")));
- assertFalse("Deny access to other app", roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a2")));
- assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
-
- RoleMembership multiContext = RoleMembership.in(SystemName.main)
- .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1"))
- .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t2"), ApplicationName.from("a2"))
- .build();
- assertFalse("Deny access to other tenant and app", multiContext.allows(Action.update, URI.create("/application/v4/tenant/t3/application/a3")));
- assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2")));
- assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
-
- RoleMembership publicSystem = RoleMembership.in(SystemName.vaas)
- .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1"))
- .build();
- assertFalse(publicSystem.allows(Action.read, URI.create("/controller/v1/foo")));
- assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
- }
-
- @Test
- public void build_service_membership() {
- RoleMembership roles = RoleMembership.in(SystemName.main)
- .add(Role.tenantPipeline).build();
- assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined")));
- assertFalse(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
- assertTrue(roles.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport")));
- assertFalse("No global read access", roles.allows(Action.read, URI.create("/controller/v1/foo")));
- }
-
- @Test
- public void multi_role_membership() {
- RoleMembership roles = RoleMembership.in(SystemName.main)
- .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1"))
- .add(Role.tenantPipeline)
- .add(Role.everyone)
- .build();
- assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined")));
- assertFalse(roles.allows(Action.create, URI.create("/controller/v1/foo")));
- assertTrue(roles.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport")));
- assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1")));
- assertTrue("Global read access", roles.allows(Action.read, URI.create("/controller/v1/foo")));
- assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/")));
- assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/d/nodes")));
- assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/statuspage/v1/incidents")));
- }
-
-}
diff --git a/document/src/main/java/com/yahoo/document/CollectionDataType.java b/document/src/main/java/com/yahoo/document/CollectionDataType.java
index a73588a710c..c6420b5e71f 100644
--- a/document/src/main/java/com/yahoo/document/CollectionDataType.java
+++ b/document/src/main/java/com/yahoo/document/CollectionDataType.java
@@ -32,7 +32,6 @@ public abstract class CollectionDataType extends DataType {
return type;
}
- @SuppressWarnings("deprecation")
public DataType getNestedType() {
return nestedType;
}
@@ -58,11 +57,7 @@ public abstract class CollectionDataType extends DataType {
return false;
}
CollectionFieldValue cfv = (CollectionFieldValue) value;
- if (equals(cfv.getDataType())) {
- //the field value if of this type:
- return true;
- }
- return false;
+ return equals(cfv.getDataType());
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 2773f9d31da..435c8fcdc65 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -38,7 +38,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
* Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions).
*/
public static TensorType convertDimensionsToMapped(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
type.dimensions().stream().forEach(dim -> builder.mapped(dim.name()));
return builder.build();
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 335cda8e133..981120af145 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -97,7 +97,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
}
public static TensorType extractSparseDimensions(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
return builder.build();
}
diff --git a/documentgen-test/etc/complex/music4.sd b/documentgen-test/etc/complex/music4.sd
index c8100ba7de2..eab0018360d 100644
--- a/documentgen-test/etc/complex/music4.sd
+++ b/documentgen-test/etc/complex/music4.sd
@@ -4,5 +4,8 @@ search music4 {
field mu4 type string {
}
+ field pos type position {
+
+ }
}
}
diff --git a/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java b/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java
index deec438a332..b6a0f165ca6 100644
--- a/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java
+++ b/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java
@@ -5,21 +5,53 @@ import com.yahoo.compress.CompressionType;
import com.yahoo.docproc.DocumentProcessor;
import com.yahoo.docproc.Processing;
import com.yahoo.docproc.proxy.ProxyDocument;
-import com.yahoo.document.*;
+import com.yahoo.document.ArrayDataType;
+import com.yahoo.document.DataType;
+import com.yahoo.document.Document;
+import com.yahoo.document.DocumentId;
+import com.yahoo.document.DocumentPut;
+import com.yahoo.document.DocumentType;
+import com.yahoo.document.DocumentTypeManager;
+import com.yahoo.document.Field;
+import com.yahoo.document.Generated;
+import com.yahoo.document.MapDataType;
+import com.yahoo.document.ReferenceDataType;
+import com.yahoo.document.StructDataType;
+import com.yahoo.document.WeightedSetDataType;
import com.yahoo.document.annotation.Annotation;
import com.yahoo.document.annotation.AnnotationType;
import com.yahoo.document.annotation.SpanTree;
import com.yahoo.document.config.DocumentmanagerConfig;
-import com.yahoo.document.datatypes.*;
-import com.yahoo.document.serialization.*;
+import com.yahoo.document.datatypes.Array;
+import com.yahoo.document.datatypes.DoubleFieldValue;
+import com.yahoo.document.datatypes.FieldValue;
+import com.yahoo.document.datatypes.FloatFieldValue;
+import com.yahoo.document.datatypes.IntegerFieldValue;
+import com.yahoo.document.datatypes.LongFieldValue;
+import com.yahoo.document.datatypes.MapFieldValue;
+import com.yahoo.document.datatypes.Raw;
+import com.yahoo.document.datatypes.ReferenceFieldValue;
+import com.yahoo.document.datatypes.StringFieldValue;
+import com.yahoo.document.datatypes.Struct;
+import com.yahoo.document.datatypes.StructuredFieldValue;
+import com.yahoo.document.datatypes.WeightedSet;
+import com.yahoo.document.serialization.DocumentDeserializerFactory;
+import com.yahoo.document.serialization.DocumentSerializer;
+import com.yahoo.document.serialization.DocumentSerializerFactory;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.searchdefinition.derived.Deriver;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.document.NodeImpl;
import com.yahoo.vespa.document.dom.DocumentImpl;
-import com.yahoo.vespa.documentgen.test.*;
+import com.yahoo.vespa.documentgen.test.Book;
import com.yahoo.vespa.documentgen.test.Book.Ss0;
import com.yahoo.vespa.documentgen.test.Book.Ss1;
+import com.yahoo.vespa.documentgen.test.Common;
+import com.yahoo.vespa.documentgen.test.ConcreteDocumentFactory;
+import com.yahoo.vespa.documentgen.test.Music;
+import com.yahoo.vespa.documentgen.test.Music3;
+import com.yahoo.vespa.documentgen.test.Music4;
+import com.yahoo.vespa.documentgen.test.Parent;
import com.yahoo.vespa.documentgen.test.annotation.Artist;
import com.yahoo.vespa.documentgen.test.annotation.Date;
import com.yahoo.vespa.documentgen.test.annotation.Emptyannotation;
@@ -32,10 +64,24 @@ import java.lang.Class;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
-import java.util.*;
-
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertFalse;
+import static junit.framework.TestCase.assertNotSame;
import static org.hamcrest.CoreMatchers.instanceOf;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertThat;
+
/**
* Testcases for vespa-documentgen-plugin
@@ -675,7 +721,7 @@ public class DocumentGenPluginTest {
}
if (generated.getAnnotation(com.yahoo.document.Generated.class)==null) return null;
Book book = new Book(d.getId());
- for (Iterator<Map.Entry<Field, FieldValue>>i=d.iterator() ; i.hasNext() ; ) {
+ for (Iterator<Map.Entry<Field, FieldValue>> i = d.iterator(); i.hasNext() ; ) {
Map.Entry<Field, FieldValue> e = i.next();
Field f = e.getKey();
FieldValue fv = e.getValue();
@@ -928,5 +974,12 @@ public class DocumentGenPluginTest {
book.setVector(Tensor.from("{{x:0}:1.0, {x:1}:2.0, {x:2}:3.0}"));
assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", book.getVector().toString());
}
+
+ @Test
+ public void testPositionType() {
+ Music4 book = new Music4(new DocumentId("doc:music4:0"));
+ book.setPos(new Music4.Position().setX(7).setY(8));
+ assertEquals(new Music4.Position().setX(7).setY(8), book.getPos());
+ }
}
diff --git a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
index ae6166f9d24..a0aeb6b63c9 100644
--- a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
@@ -40,6 +40,7 @@ assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor)
struct Fixture
{
Builder builder;
+ Fixture() : builder() {}
};
Tensor::UP
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
index 7aa1d71fe9a..b7aa988775d 100644
--- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
+++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
@@ -11,6 +11,7 @@
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/objects/hexdump.h>
#include <ostream>
+#include <vespa/eval/tensor/dense/dense_tensor_view.h>
using namespace vespalib::tensor;
using vespalib::nbostream;
@@ -32,9 +33,7 @@ std::ostream &operator<<(std::ostream &out, const std::vector<uint8_t> &rhs)
}
-namespace vespalib {
-
-namespace tensor {
+namespace vespalib::tensor {
static bool operator==(const Tensor &lhs, const Tensor &rhs)
{
@@ -42,7 +41,6 @@ static bool operator==(const Tensor &lhs, const Tensor &rhs)
}
}
-}
template <class BuilderType>
void
@@ -69,7 +67,7 @@ struct Fixture
Fixture() : _builder() {}
Tensor::UP createTensor(const TensorCells &cells) {
- return vespalib::tensor::TensorFactory::create(cells, _builder);
+ return TensorFactory::create(cells, _builder);
}
Tensor::UP createTensor(const TensorCells &cells, const TensorDimensions &dimensions) {
return TensorFactory::create(cells, dimensions, _builder);
@@ -84,7 +82,7 @@ struct Fixture
auto formatId = wrapStream.getInt1_4Bytes();
ASSERT_EQUAL(formatId, 1u); // sparse format
SparseBinaryFormat::deserialize(wrapStream, builder);
- EXPECT_TRUE(wrapStream.size() == 0);
+ EXPECT_TRUE(wrapStream.empty());
auto ret = builder.build();
checkDeserialize<BuilderType>(stream, *ret);
stream.adjustReadPos(stream.size());
@@ -162,93 +160,129 @@ struct DenseFixture
return ret;
}
void assertSerialized(const ExpBuffer &exp, const DenseTensorCells &rhs) {
+ assertSerialized(exp, SerializeFormat::DOUBLE, rhs);
+ }
+ template <typename T>
+ void assertCellsOnly(const ExpBuffer &exp, const DenseTensorView & rhs) {
+ nbostream a(&exp[0], exp.size());
+ std::vector<T> v;
+ TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(a, v);
+ EXPECT_EQUAL(v.size(), rhs.cellsRef().size());
+ for (size_t i(0); i < v.size(); i++) {
+ EXPECT_EQUAL(v[i], rhs.cellsRef()[i]);
+ }
+ }
+ void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) {
Tensor::UP rhsTensor(createTensor(rhs));
nbostream rhsStream;
- serialize(rhsStream, *rhsTensor);
+ TypedBinaryFormat::serialize(rhsStream, *rhsTensor, cellType);
EXPECT_EQUAL(exp, rhsStream);
auto rhs2 = deserialize(rhsStream);
EXPECT_EQUAL(*rhs2, *rhsTensor);
+
+ assertCellsOnly<float>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
+ assertCellsOnly<double>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
}
};
-TEST_F("test tensor serialization for DenseTensor", DenseFixture)
-{
- TEST_DO(f.assertSerialized({ 0x02, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00},
+TEST_F("test tensor serialization for DenseTensor", DenseFixture) {
+ TEST_DO(f.assertSerialized({0x02, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
{}));
- TEST_DO(f.assertSerialized({ 0x02, 0x01, 0x01, 0x78, 0x01,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00},
- { {{{"x",0}}, 0} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01,
- 0x01, 0x79, 0x01,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",0},{"y", 0}}, 0} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x01, 0x01, 0x78, 0x02,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x40, 0x08, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",1}}, 3} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01,
- 0x01, 0x79, 0x01,
- 0x40, 0x08, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",0},{"y",0}}, 3} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x02,
- 0x01, 0x79, 0x01,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x40, 0x08, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",1},{"y",0}}, 3} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01,
- 0x01, 0x79, 0x04,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x40, 0x08, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",0},{"y",3}}, 3} }));
- TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x03,
- 0x01, 0x79, 0x05,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00,
- 0x40, 0x08, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00 },
- { {{{"x",2}, {"y",4}}, 3} }));
+ TEST_DO(f.assertSerialized({0x02, 0x01, 0x01, 0x78, 0x01,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 0}}, 0}}));
+ TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01,
+ 0x01, 0x79, 0x01,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 0}, {"y", 0}}, 0}}));
+ TEST_DO(f.assertSerialized({0x02, 0x01, 0x01, 0x78, 0x02,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x08, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 1}}, 3}}));
+ TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01,
+ 0x01, 0x79, 0x01,
+ 0x40, 0x08, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 0}, {"y", 0}}, 3}}));
+ TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x02,
+ 0x01, 0x79, 0x01,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x08, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 1}, {"y", 0}}, 3}}));
+ TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01,
+ 0x01, 0x79, 0x04,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x08, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 0}, {"y", 3}}, 3}}));
+ TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x03,
+ 0x01, 0x79, 0x05,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x08, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00},
+ {{{{"x", 2}, {"y", 4}}, 3}}));
+}
+
+TEST_F("test 'float' cells", DenseFixture) {
+ TEST_DO(f.assertSerialized({0x06, 0x01, 0x02, 0x01, 0x78, 0x03,
+ 0x01, 0x79, 0x05,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x40, 0x00, 0x00 },
+ SerializeFormat::FLOAT, { {{{"x",2}, {"y",4}}, 3} }));
}
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 564d6a6b84e..9d95d91ae15 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -4,7 +4,6 @@
#include <vespa/vespalib/stllike/string.h>
#include <vector>
-#include <memory>
namespace vespalib::eval {
@@ -36,11 +35,12 @@ public:
};
private:
- Type _type;
+ Type _type;
std::vector<Dimension> _dimensions;
- explicit ValueType(Type type_in)
+ ValueType(Type type_in)
: _type(type_in), _dimensions() {}
+
ValueType(Type type_in, std::vector<Dimension> &&dimensions_in)
: _type(type_in), _dimensions(std::move(dimensions_in)) {}
diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp
index cf0fb6d493a..229a9201f08 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.cpp
+++ b/eval/src/vespa/eval/eval/value_type_spec.cpp
@@ -6,9 +6,7 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <algorithm>
-namespace vespalib {
-namespace eval {
-namespace value_type {
+namespace vespalib::eval::value_type {
namespace {
@@ -205,6 +203,4 @@ to_spec(const ValueType &type)
return os.str();
}
-} // namespace vespalib::eval::value_type
-} // namespace vespalib::eval
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/eval/value_type_spec.h b/eval/src/vespa/eval/eval/value_type_spec.h
index 76d50834ae8..f2609f59f32 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.h
+++ b/eval/src/vespa/eval/eval/value_type_spec.h
@@ -4,15 +4,11 @@
#include "value_type.h"
-namespace vespalib {
-namespace eval {
-namespace value_type {
+namespace vespalib::eval::value_type {
ValueType parse_spec(const char *pos_in, const char *end_in, const char *&pos_out);
ValueType from_spec(const vespalib::string &str);
vespalib::string to_spec(const ValueType &type);
-} // namespace vespalib::eval::value_type
-} // namespace vespalib::eval
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/default_tensor.h b/eval/src/vespa/eval/tensor/default_tensor.h
index 202b482e300..456d3333295 100644
--- a/eval/src/vespa/eval/tensor/default_tensor.h
+++ b/eval/src/vespa/eval/tensor/default_tensor.h
@@ -5,13 +5,11 @@
#include "sparse/sparse_tensor.h"
#include "sparse/sparse_tensor_builder.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
struct DefaultTensor {
using type = SparseTensor;
using builder = SparseTensorBuilder;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 554953288e1..5a16511fe71 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -26,8 +26,7 @@
#include <vespa/log/log.h>
LOG_SETUP(".eval.tensor.default_tensor_engine");
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using eval::Aggr;
using eval::Aggregator;
@@ -390,5 +389,4 @@ DefaultTensorEngine::rename(const Value &a, const std::vector<vespalib::string>
//-----------------------------------------------------------------------------
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h
index 755bdcf6a9d..b7a9e4d43e7 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.h
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h
@@ -4,8 +4,7 @@
#include <vespa/eval/eval/tensor_engine.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/**
* This is a tensor engine implementation wrapping the default tensor
@@ -34,5 +33,4 @@ public:
const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
index e775385b623..c183e5c1db3 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp
@@ -75,6 +75,8 @@ DenseTensor::DenseTensor(eval::ValueType &&type_in,
checkCellsSize(*this);
}
+DenseTensor::~DenseTensor() = default;
+
bool
DenseTensor::operator==(const DenseTensor &rhs) const
{
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
index 0da5f570674..3795831c914 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h
@@ -2,10 +2,6 @@
#pragma once
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/types.h>
-#include <vespa/eval/eval/value_type.h>
-#include "dense_tensor_cells_iterator.h"
#include "dense_tensor_view.h"
namespace vespalib::tensor {
@@ -17,20 +13,16 @@ namespace vespalib::tensor {
class DenseTensor : public DenseTensorView
{
public:
- typedef std::unique_ptr<DenseTensor> UP;
- using Cells = std::vector<double>;
-
-private:
- eval::ValueType _type;
- Cells _cells;
-
-public:
DenseTensor();
- ~DenseTensor() {}
+ ~DenseTensor() override;
DenseTensor(const eval::ValueType &type_in, const Cells &cells_in);
DenseTensor(const eval::ValueType &type_in, Cells &&cells_in);
DenseTensor(eval::ValueType &&type_in, Cells &&cells_in);
bool operator==(const DenseTensor &rhs) const;
+private:
+ eval::ValueType _type;
+ Cells _cells;
+
};
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index 25478510587..fa1e59c87db 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -74,7 +74,7 @@ apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func)
}
const DenseTensor *dense = dynamic_cast<const DenseTensor *>(&rhs);
if (dense) {
- return apply(lhs, DenseTensorView(*dense), func);
+ return apply(lhs, *dense, func);
}
return Tensor::UP();
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
index 5d52e5f6e0e..cd4738cf1ee 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp
@@ -6,7 +6,6 @@
#include <limits>
#include <algorithm>
-
using vespalib::IllegalArgumentException;
using vespalib::make_string;
@@ -83,8 +82,7 @@ DenseTensorBuilder::calculateCellAddress()
const auto &dim = _dimensions[i];
if (label == UNDEFINED_LABEL) {
throw IllegalArgumentException(make_string("Label for dimension '%s' is undefined. "
- "Expected a value in the range [0, %u>",
- dim.name.c_str(), dim.size));
+ "Expected a value in the range [0, %u>", dim.name.c_str(), dim.size));
}
result += (label * multiplier);
multiplier *= dim.size;
@@ -102,12 +100,10 @@ DenseTensorBuilder::DenseTensorBuilder()
{
}
-DenseTensorBuilder::~DenseTensorBuilder() {
-}
+DenseTensorBuilder::~DenseTensorBuilder() = default;
DenseTensorBuilder::Dimension
-DenseTensorBuilder::defineDimension(const vespalib::string &dimension,
- size_t dimensionSize)
+DenseTensorBuilder::defineDimension(const vespalib::string &dimension, size_t dimensionSize)
{
auto itr = _dimensionsEnum.find(dimension);
if (itr != _dimensionsEnum.end()) {
@@ -135,8 +131,7 @@ DenseTensorBuilder::addLabel(Dimension dimension, size_t label)
Dimension mappedDimension = _dimensionsMapping[dimension];
const auto &dim = _dimensions[mappedDimension];
validateLabelInRange(label, dim.size, dim.name);
- validateLabelNotSpecified(_addressBuilder[mappedDimension],
- dim.name);
+ validateLabelNotSpecified(_addressBuilder[mappedDimension], dim.name);
_addressBuilder[mappedDimension] = label;
return *this;
}
@@ -154,14 +149,13 @@ DenseTensorBuilder::addCell(double value)
return *this;
}
-Tensor::UP
+std::unique_ptr<DenseTensor>
DenseTensorBuilder::build()
{
if (_cells.empty()) {
allocateCellsStorage();
}
- Tensor::UP result = std::make_unique<DenseTensor>(makeValueType(std::move(_dimensions)),
- std::move(_cells));
+ auto result = std::make_unique<DenseTensor>(makeValueType(std::move(_dimensions)), std::move(_cells));
_dimensionsEnum.clear();
_dimensions.clear();
DenseTensor::Cells().swap(_cells);
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
index 3969a9335b8..05cd88b1319 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h
@@ -15,12 +15,11 @@ class DenseTensorBuilder
{
public:
using Dimension = TensorBuilder::Dimension;
-
private:
vespalib::hash_map<vespalib::string, size_t> _dimensionsEnum;
std::vector<eval::ValueType::Dimension> _dimensions;
- DenseTensor::Cells _cells;
- std::vector<size_t> _addressBuilder;
+ DenseTensor::Cells _cells;
+ std::vector<size_t> _addressBuilder;
std::vector<Dimension> _dimensionsMapping;
void allocateCellsStorage();
@@ -34,7 +33,7 @@ public:
Dimension defineDimension(const vespalib::string &dimension, size_t dimensionSize);
DenseTensorBuilder &addLabel(Dimension dimension, size_t label);
DenseTensorBuilder &addCell(double value);
- Tensor::UP build();
+ std::unique_ptr<DenseTensor> build();
};
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
index 447d8a4f805..caf92d6c8c7 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h
@@ -2,10 +2,7 @@
#pragma once
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/types.h>
#include <vespa/eval/eval/value_type.h>
-#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/util/arrayref.h>
namespace vespalib::tensor {
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp
index 8480e7418e1..98db89dd2a7 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp
@@ -25,7 +25,7 @@ public:
~DimensionReducer();
template <typename Function>
- DenseTensor::UP
+ std::unique_ptr<DenseTensorView>
reduceCells(CellsRef cellsIn, Function &&func) {
auto itr_in = cellsIn.cbegin();
auto itr_out = _cellsResult.begin();
@@ -54,7 +54,7 @@ public:
namespace {
template <typename Function>
-DenseTensor::UP
+std::unique_ptr<DenseTensorView>
reduce(const DenseTensorView &tensor, const vespalib::string &dimensionToRemove, Function &&func)
{
DimensionReducer reducer(tensor.fast_type(), dimensionToRemove);
@@ -70,9 +70,9 @@ reduce(const DenseTensorView &tensor, const std::vector<vespalib::string> &dimen
if (dimensions.size() == 1) {
return reduce(tensor, dimensions[0], func);
} else if (dimensions.size() > 0) {
- DenseTensor::UP result = reduce(tensor, dimensions[0], func);
+ std::unique_ptr<DenseTensorView> result = reduce(tensor, dimensions[0], func);
for (size_t i = 1; i < dimensions.size(); ++i) {
- DenseTensor::UP tmpResult = reduce(DenseTensorView(*result), dimensions[i], func);
+ std::unique_ptr<DenseTensorView> tmpResult = reduce(*result, dimensions[i], func);
result = std::move(tmpResult);
}
return result;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 164ec042384..73b2e7b3ffb 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -134,14 +134,6 @@ bool sameCells(DenseTensorView::CellsRef lhs, DenseTensorView::CellsRef rhs)
}
-
-DenseTensorView::DenseTensorView(const DenseTensor &rhs)
- : _typeRef(rhs.fast_type()),
- _cellsRef(rhs.cellsRef())
-{
-}
-
-
bool
DenseTensorView::operator==(const DenseTensorView &rhs) const
{
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 11ed9639cc6..09b6b72375e 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -2,15 +2,11 @@
#pragma once
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/types.h>
-#include <vespa/eval/eval/value_type.h>
#include "dense_tensor_cells_iterator.h"
+#include <vespa/eval/tensor/tensor.h>
namespace vespalib::tensor {
-class DenseTensor;
-
/**
* A view to a dense tensor where all dimensions are indexed.
* Tensor cells are stored in an underlying array according to the order of the dimensions.
@@ -23,26 +19,15 @@ public:
using CellsIterator = DenseTensorCellsIterator;
using Address = std::vector<eval::ValueType::Dimension::size_type>;
-private:
- const eval::ValueType &_typeRef;
- Tensor::UP reduce_all(join_fun_t op, const std::vector<vespalib::string> &dimensions) const;
-protected:
- CellsRef _cellsRef;
-
- void initCellsRef(CellsRef cells_in) {
- _cellsRef = cells_in;
- }
-
-public:
- explicit DenseTensorView(const DenseTensor &rhs);
DenseTensorView(const eval::ValueType &type_in, CellsRef cells_in)
: _typeRef(type_in),
_cellsRef(cells_in)
{}
- DenseTensorView(const eval::ValueType &type_in)
- : _typeRef(type_in),
- _cellsRef()
+ explicit DenseTensorView(const eval::ValueType &type_in)
+ : _typeRef(type_in),
+ _cellsRef()
{}
+
const eval::ValueType &fast_type() const { return _typeRef; }
const CellsRef &cellsRef() const { return _cellsRef; }
bool operator==(const DenseTensorView &rhs) const;
@@ -60,6 +45,15 @@ public:
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
void accept(TensorVisitor &visitor) const override;
+protected:
+ void initCellsRef(CellsRef cells_in) {
+ _cellsRef = cells_in;
+ }
+private:
+ Tensor::UP reduce_all(join_fun_t op, const std::vector<vespalib::string> &dimensions) const;
+
+ const eval::ValueType &_typeRef;
+ CellsRef _cellsRef;
};
}
diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
index 2132f861896..260e71b6f76 100644
--- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
@@ -44,7 +44,7 @@ public:
MutableDenseTensorView(eval::ValueType type_in);
MutableDenseTensorView(eval::ValueType type_in, CellsRef cells_in);
void setCells(CellsRef cells_in) {
- _cellsRef = cells_in;
+ initCellsRef(cells_in);
}
void setUnboundDimensions(const uint32_t *unboundDimSizeBegin, const uint32_t *unboundDimSizeEnd) {
_concreteType.setUnboundDimensions(unboundDimSizeBegin, unboundDimSizeEnd);
diff --git a/eval/src/vespa/eval/tensor/join_tensors.h b/eval/src/vespa/eval/tensor/join_tensors.h
index 86e5913d8f5..271a6b0195d 100644
--- a/eval/src/vespa/eval/tensor/join_tensors.h
+++ b/eval/src/vespa/eval/tensor/join_tensors.h
@@ -5,8 +5,7 @@
#include "tensor.h"
#include "direct_tensor_builder.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
/*
* Join the cells of two tensors.
@@ -44,5 +43,4 @@ joinTensorsNegated(const TensorImplType &lhs,
return builder.build();
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h
new file mode 100644
index 00000000000..40b1840be6e
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/serialization/common.h
@@ -0,0 +1,9 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib::tensor {
+
+enum class SerializeFormat {FLOAT, DOUBLE};
+
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
index feb811a92de..4b1ccc8db5d 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
@@ -3,48 +3,48 @@
#include "dense_binary_format.h"
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/util/exceptions.h>
#include <cassert>
using vespalib::nbostream;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
+
+using Dimension = eval::ValueType::Dimension;
+
namespace {
eval::ValueType
-makeValueType(std::vector<eval::ValueType::Dimension> &&dimensions) {
+makeValueType(std::vector<Dimension> &&dimensions) {
return (dimensions.empty() ?
eval::ValueType::double_type() :
eval::ValueType::tensor_type(std::move(dimensions)));
}
-}
-
-void
-DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
-{
- stream.putInt1_4Bytes(tensor.fast_type().dimensions().size());
+size_t
+encodeDimensions(nbostream &stream, const eval::ValueType & type) {
+ stream.putInt1_4Bytes(type.dimensions().size());
size_t cellsSize = 1;
- for (const auto &dimension : tensor.fast_type().dimensions()) {
+ for (const auto &dimension : type.dimensions()) {
stream.writeSmallString(dimension.name);
stream.putInt1_4Bytes(dimension.size);
cellsSize *= dimension.size;
}
- DenseTensorView::CellsRef cells = tensor.cellsRef();
- assert(cells.size() == cellsSize);
+ return cellsSize;
+}
+
+template<typename T>
+void
+encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) {
for (const auto &value : cells) {
- stream << value;
+ stream << static_cast<T>(value);
}
}
-
-std::unique_ptr<DenseTensor>
-DenseBinaryFormat::deserialize(nbostream &stream)
-{
+size_t
+decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
vespalib::string dimensionName;
- std::vector<eval::ValueType::Dimension> dimensions;
- DenseTensor::Cells cells;
size_t dimensionsSize = stream.getInt1_4Bytes();
size_t dimensionSize;
size_t cellsSize = 1;
@@ -54,16 +54,76 @@ DenseBinaryFormat::deserialize(nbostream &stream)
dimensions.emplace_back(dimensionName, dimensionSize);
cellsSize *= dimensionSize;
}
- cells.reserve(cellsSize);
- double cellValue = 0.0;
+ return cellsSize;
+}
+
+template<typename T, typename V>
+void
+decodeCells(nbostream &stream, size_t cellsSize, V & cells) {
+ T cellValue = 0.0;
for (size_t i = 0; i < cellsSize; ++i) {
stream >> cellValue;
cells.emplace_back(cellValue);
}
- return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)),
- std::move(cells));
}
+template <typename V>
+void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells)
+{
+ switch (format) {
+ case SerializeFormat::DOUBLE:
+ decodeCells<double>(stream, cellsSize, cells);
+ break;
+ case SerializeFormat::FLOAT:
+ decodeCells<float>(stream, cellsSize, cells);
+ break;
+ }
+}
+
+}
-} // namespace vespalib::tensor
-} // namespace vespalib
+void
+DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
+{
+ size_t cellsSize = encodeDimensions(stream, tensor.fast_type());
+
+ DenseTensorView::CellsRef cells = tensor.cellsRef();
+ assert(cells.size() == cellsSize);
+ switch (_format) {
+ case SerializeFormat::DOUBLE:
+ encodeCells<double>(stream, cells);
+ break;
+ case SerializeFormat::FLOAT:
+ encodeCells<float>(stream, cells);
+ break;
+ }
+}
+
+std::unique_ptr<DenseTensor>
+DenseBinaryFormat::deserialize(nbostream &stream)
+{
+ std::vector<Dimension> dimensions;
+ size_t cellsSize = decodeDimensions(stream,dimensions);
+ DenseTensor::Cells cells;
+ cells.reserve(cellsSize);
+
+ decodeCells(_format, stream, cellsSize, cells);
+
+ return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells));
+}
+
+template <typename T>
+void
+DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells)
+{
+ std::vector<Dimension> dimensions;
+ size_t cellsSize = decodeDimensions(stream,dimensions);
+ cells.clear();
+ cells.reserve(cellsSize);
+ decodeCells(_format, stream, cellsSize, cells);
+}
+
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells);
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells);
+
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
index 8019648ffcb..f9847d37784 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
@@ -2,12 +2,13 @@
#pragma once
+#include "common.h"
#include <memory>
-namespace vespalib {
+#include <vector>
-class nbostream;
+namespace vespalib { class nbostream; }
-namespace tensor {
+namespace vespalib::tensor {
class DenseTensor;
class DenseTensorView;
@@ -18,9 +19,15 @@ class DenseTensorView;
class DenseBinaryFormat
{
public:
- static void serialize(nbostream &stream, const DenseTensorView &tensor);
- static std::unique_ptr<DenseTensor> deserialize(nbostream &stream);
+ DenseBinaryFormat(SerializeFormat format) : _format(format) { }
+ void serialize(nbostream &stream, const DenseTensorView &tensor);
+ std::unique_ptr<DenseTensor> deserialize(nbostream &stream);
+
+ // This is a temporary method untill we get full support for typed tensors
+ template <typename T>
+ void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells);
+private:
+ SerializeFormat _format;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp
index 7ae3957dc0f..ece3c2e4a07 100644
--- a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp
@@ -10,8 +10,7 @@
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/data/memory.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
using slime::Inserter;
@@ -58,13 +57,10 @@ SlimeBinaryFormatSerializer::SlimeBinaryFormatSerializer(Inserter &inserter)
}
-SlimeBinaryFormatSerializer::~SlimeBinaryFormatSerializer()
-{
-}
+SlimeBinaryFormatSerializer::~SlimeBinaryFormatSerializer() = default;
void
-SlimeBinaryFormatSerializer::visit(const TensorAddress &address,
- double value)
+SlimeBinaryFormatSerializer::visit(const TensorAddress &address, double value)
{
Cursor &cellCursor = _cells.addObject();
writeTensorAddress(cellCursor, address);
@@ -101,5 +97,4 @@ SlimeBinaryFormat::serialize(const Tensor &tensor)
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h
index f1366c64e2c..c9e9ff2c3e9 100644
--- a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h
@@ -4,13 +4,11 @@
#include <memory>
-namespace vespalib {
+namespace vespalib { class Slime; }
-class Slime;
+namespace vespalib::slime { struct Inserter; }
-namespace slime { struct Inserter; }
-
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
class TensorBuilder;
@@ -25,5 +23,4 @@ public:
static std::unique_ptr<Slime> serialize(const Tensor &tensor);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
index bd0c5b25f93..79d1aaa83a8 100644
--- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
@@ -11,8 +11,7 @@
using vespalib::nbostream;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
namespace {
@@ -59,13 +58,10 @@ SparseBinaryFormatSerializer::SparseBinaryFormatSerializer()
}
-SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer()
-{
-}
+SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer() = default;
void
-SparseBinaryFormatSerializer::visit(const TensorAddress &address,
- double value)
+SparseBinaryFormatSerializer::visit(const TensorAddress &address, double value)
{
++_numCells;
writeTensorAddress(_cells, _type, address);
@@ -74,8 +70,7 @@ SparseBinaryFormatSerializer::visit(const TensorAddress &address,
void
-SparseBinaryFormatSerializer::serialize(nbostream &stream,
- const Tensor &tensor)
+SparseBinaryFormatSerializer::serialize(nbostream &stream, const Tensor &tensor)
{
_type = tensor.type();
tensor.accept(*this);
@@ -121,5 +116,4 @@ SparseBinaryFormat::deserialize(nbostream &stream, TensorBuilder &builder)
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
index db05574dfce..89f6947ad43 100644
--- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
@@ -2,11 +2,9 @@
#pragma once
-namespace vespalib {
+namespace vespalib { class nbostream; }
-class nbostream;
-
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
class TensorBuilder;
@@ -21,5 +19,4 @@ public:
static void deserialize(nbostream &stream, TensorBuilder &builder);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index fe35ce4c831..4ca037e82a4 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -11,20 +11,64 @@
#include <vespa/eval/tensor/wrapped_simple_tensor.h>
#include <vespa/log/log.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/exceptions.h>
+
LOG_SETUP(".eval.tensor.serialization.typed_binary_format");
using vespalib::nbostream;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
+
+namespace {
+
+constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u;
+constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u;
+constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u;
+constexpr uint32_t SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5u; //Future
+constexpr uint32_t DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6u;
+constexpr uint32_t MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7u; //Future
+
+constexpr uint32_t DOUBLE_VALUE_TYPE = 0;
+constexpr uint32_t FLOAT_VALUE_TYPE = 1;
+
+uint32_t
+format2Encoding(SerializeFormat format) {
+ switch (format) {
+ case SerializeFormat::DOUBLE:
+ return DOUBLE_VALUE_TYPE;
+ case SerializeFormat::FLOAT:
+ return FLOAT_VALUE_TYPE;
+ }
+ abort();
+}
+
+SerializeFormat
+encoding2Format(uint32_t serializedType) {
+ switch (serializedType) {
+ case DOUBLE_VALUE_TYPE:
+ return SerializeFormat::DOUBLE;
+ case FLOAT_VALUE_TYPE:
+ return SerializeFormat::FLOAT;
+ default:
+ throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType));
+ }
+}
+}
void
-TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
+TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format)
{
if (auto denseTensor = dynamic_cast<const DenseTensorView *>(&tensor)) {
- stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- DenseBinaryFormat::serialize(stream, *denseTensor);
+ if (format != SerializeFormat::DOUBLE) {
+ stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
+ stream.putInt1_4Bytes(format2Encoding(format));
+ DenseBinaryFormat(format).serialize(stream, *denseTensor);
+ } else {
+ stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
+ DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor);
+ }
} else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) {
eval::SimpleTensor::encode(wrapped->get(), stream);
} else {
@@ -45,15 +89,33 @@ TypedBinaryFormat::deserialize(nbostream &stream)
return builder.build();
}
if (formatId == DENSE_BINARY_FORMAT_TYPE) {
- return DenseBinaryFormat::deserialize(stream);
+ return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream);
+ }
+ if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) {
+ return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserialize(stream);
}
if (formatId == MIXED_BINARY_FORMAT_TYPE) {
stream.adjustReadPos(read_pos - stream.rp());
return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream));
}
- LOG_ABORT("should not be reached");
+ abort();
}
+template <typename T>
+void
+TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells)
+{
+ auto formatId = stream.getInt1_4Bytes();
+ if (formatId == DENSE_BINARY_FORMAT_TYPE) {
+ return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells);
+ }
+ if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) {
+ return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells);
+ }
+ abort();
+}
-} // namespace vespalib::tensor
-} // namespace vespalib
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells);
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells);
+
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
index c655210907f..717d51effef 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
@@ -2,30 +2,32 @@
#pragma once
+#include "common.h"
#include <memory>
-#include <cstdint>
+#include <vector>
-namespace vespalib {
+namespace vespalib { class nbostream; }
-class nbostream;
-
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
-class TensorBuilder;
/**
* Class for serializing a tensor.
*/
class TypedBinaryFormat
{
- static constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u;
- static constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u;
- static constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u;
public:
- static void serialize(nbostream &stream, const Tensor &tensor);
+ static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format);
+ static void serialize(nbostream &stream, const Tensor &tensor) {
+ serialize(stream, tensor, SerializeFormat::DOUBLE);
+ }
+
static std::unique_ptr<Tensor> deserialize(nbostream &stream);
+
+ // This is a temporary method until we get full support for typed tensors
+ template <typename T>
+ static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor.cpp b/eval/src/vespa/eval/tensor/tensor.cpp
index 8715a864f68..51c94aab5b0 100644
--- a/eval/src/vespa/eval/tensor/tensor.cpp
+++ b/eval/src/vespa/eval/tensor/tensor.cpp
@@ -4,8 +4,7 @@
#include "default_tensor_engine.h"
#include <sstream>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
Tensor::Tensor()
: eval::Tensor(DefaultTensorEngine::ref())
@@ -34,5 +33,4 @@ operator<<(std::ostream &out, const Tensor &value)
return out;
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index 4061ed9c115..edf5fa710e3 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -9,9 +9,8 @@
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/value_type.h>
-namespace vespalib {
-namespace eval { struct BinaryOperation; }
-namespace tensor {
+namespace vespalib::eval { struct BinaryOperation; }
+namespace vespalib::tensor {
class TensorVisitor;
class CellValues;
@@ -66,5 +65,4 @@ public:
std::ostream &operator<<(std::ostream &out, const Tensor &value);
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_address.cpp b/eval/src/vespa/eval/tensor/tensor_address.cpp
index afadcf2c668..a68fc5d3353 100644
--- a/eval/src/vespa/eval/tensor/tensor_address.cpp
+++ b/eval/src/vespa/eval/tensor/tensor_address.cpp
@@ -4,19 +4,18 @@
#include <algorithm>
#include <ostream>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
const vespalib::string TensorAddress::Element::UNDEFINED_LABEL = "(undefined)";
-TensorAddress::Element::~Element() {}
+TensorAddress::Element::~Element() = default;
TensorAddress::TensorAddress()
: _elements()
{
}
-TensorAddress::~TensorAddress() {}
+TensorAddress::~TensorAddress() = default;
TensorAddress::TensorAddress(const Elements &elements_in)
: _elements(elements_in)
@@ -87,5 +86,4 @@ operator<<(std::ostream &out, const TensorAddress &value)
return out;
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h b/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h
index e413712362f..01710105840 100644
--- a/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h
+++ b/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h
@@ -2,11 +2,10 @@
#pragma once
-#include <vespa/vespalib/stllike/hash_set.h>
+#include <vespa/vespalib/stllike/string.h>
namespace vespalib::tensor {
-using DimensionsSet = vespalib::hash_set<vespalib::stringref>;
/**
* An iterator for tensor address elements used to simplify 3-way merge
diff --git a/eval/src/vespa/eval/tensor/tensor_builder.h b/eval/src/vespa/eval/tensor/tensor_builder.h
index 05238b27df5..30eef5f9c54 100644
--- a/eval/src/vespa/eval/tensor/tensor_builder.h
+++ b/eval/src/vespa/eval/tensor/tensor_builder.h
@@ -4,8 +4,7 @@
#include <vespa/vespalib/stllike/string.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
@@ -30,5 +29,4 @@ public:
virtual std::unique_ptr<Tensor> build() = 0;
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_factory.cpp b/eval/src/vespa/eval/tensor/tensor_factory.cpp
index f88ae22c083..0b7fa3b9c2e 100644
--- a/eval/src/vespa/eval/tensor/tensor_factory.cpp
+++ b/eval/src/vespa/eval/tensor/tensor_factory.cpp
@@ -5,12 +5,10 @@
#include "tensor_builder.h"
#include <vespa/eval/tensor/dense/dense_tensor_builder.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
std::unique_ptr<Tensor>
-TensorFactory::create(const TensorCells &cells,
- TensorBuilder &builder) {
+TensorFactory::create(const TensorCells &cells, TensorBuilder &builder) {
for (const auto &cell : cells) {
for (const auto &addressElem : cell.first) {
const auto &dimension = addressElem.first;
@@ -30,9 +28,7 @@ TensorFactory::create(const TensorCells &cells,
std::unique_ptr<Tensor>
-TensorFactory::create(const TensorCells &cells,
- const TensorDimensions &dimensions,
- TensorBuilder &builder) {
+TensorFactory::create(const TensorCells &cells, const TensorDimensions &dimensions, TensorBuilder &builder) {
for (const auto &dimension : dimensions) {
builder.define_dimension(dimension);
}
@@ -47,17 +43,12 @@ TensorFactory::createDense(const DenseTensorCells &cells)
DenseTensorBuilder builder;
for (const auto &cell : cells) {
for (const auto &addressElem : cell.first) {
- dimensionSizes[addressElem.first] =
- std::max(dimensionSizes[addressElem.first],
- (addressElem.second + 1));
+ dimensionSizes[addressElem.first] = std::max(dimensionSizes[addressElem.first], (addressElem.second + 1));
}
}
- std::map<std::string,
- typename DenseTensorBuilder::Dimension> dimensionEnums;
+ std::map<std::string, typename DenseTensorBuilder::Dimension> dimensionEnums;
for (const auto &dimensionElem : dimensionSizes) {
- dimensionEnums[dimensionElem.first] =
- builder.defineDimension(dimensionElem.first,
- dimensionElem.second);
+ dimensionEnums[dimensionElem.first] = builder.defineDimension(dimensionElem.first, dimensionElem.second);
}
for (const auto &cell : cells) {
for (const auto &addressElem : cell.first) {
@@ -71,5 +62,4 @@ TensorFactory::createDense(const DenseTensorCells &cells)
}
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_factory.h b/eval/src/vespa/eval/tensor/tensor_factory.h
index 5fe31afc4dd..5364c28c8ff 100644
--- a/eval/src/vespa/eval/tensor/tensor_factory.h
+++ b/eval/src/vespa/eval/tensor/tensor_factory.h
@@ -4,8 +4,7 @@
#include "types.h"
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
@@ -20,11 +19,9 @@ public:
static std::unique_ptr<Tensor>
create(const TensorCells &cells, TensorBuilder &builder);
static std::unique_ptr<Tensor>
- create(const TensorCells &cells, const TensorDimensions &dimensions,
- TensorBuilder &builder);
+ create(const TensorCells &cells, const TensorDimensions &dimensions, TensorBuilder &builder);
static std::unique_ptr<Tensor>
createDense(const DenseTensorCells &cells);
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.cpp b/eval/src/vespa/eval/tensor/tensor_mapper.cpp
index c91237e4994..dbf1965d441 100644
--- a/eval/src/vespa/eval/tensor/tensor_mapper.cpp
+++ b/eval/src/vespa/eval/tensor/tensor_mapper.cpp
@@ -53,9 +53,7 @@ SparseTensorMapper(const ValueType &type)
}
template <class TensorT>
-SparseTensorMapper<TensorT>::~SparseTensorMapper()
-{
-}
+SparseTensorMapper<TensorT>::~SparseTensorMapper() = default;
template <class TensorT>
std::unique_ptr<Tensor>
diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.h b/eval/src/vespa/eval/tensor/tensor_mapper.h
index 99994bd15e8..95c6cce8fc6 100644
--- a/eval/src/vespa/eval/tensor/tensor_mapper.h
+++ b/eval/src/vespa/eval/tensor/tensor_mapper.h
@@ -4,8 +4,7 @@
#include <vespa/eval/eval/value_type.h>
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
class Tensor;
@@ -42,5 +41,4 @@ public:
};
-} // namespace vespalib::tensor
-} // namespace vespalib
+}
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index c4c3021d607..0612cee040c 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -14,6 +14,24 @@ import static com.yahoo.vespa.flags.FetchVector.Dimension.HOSTNAME;
import static com.yahoo.vespa.flags.FetchVector.Dimension.NODE_TYPE;
/**
+ * Definitions of feature flags.
+ *
+ * <p>To use feature flags, define the flag in this class as an "unbound" flag, e.g. {@link UnboundBooleanFlag}
+ * or {@link UnboundStringFlag}. At the location you want to get the value of the flag, you need the following:</p>
+ *
+ * <ol>
+ * <li>The unbound flag</li>
+ * <li>A {@link FlagSource}. The flag source is typically available as an injectible component. Binding
+ * an unbound flag to a flag source produces a (bound) flag, e.g. {@link BooleanFlag} and {@link StringFlag}.</li>
+ * <li>If you would like your flag value to be dependent on e.g. the application ID, then 1. you should
+ * declare this in the unbound flag definition in this file (referring to
+ * {@link FetchVector.Dimension#APPLICATION_ID}), and 2. specify the application ID when retrieving the value, e.g.
+ * {@link BooleanFlag#with(FetchVector.Dimension, String)}. See {@link FetchVector} for more info.</li>
+ * </ol>
+ *
+ * <p>Once the code is in place, you can override the flag value. This depends on the flag source, but typically
+ * there is a REST API for updating the flags in the config server, which is the root of all flag sources in the zone.</p>
+ *
* @author hakonhall
*/
public class Flags {
@@ -137,6 +155,12 @@ public class Flags {
"Takes effect at redeployment",
APPLICATION_ID);
+ public static final UnboundBooleanFlag REDIRECT_LEGACY_DNS_NAMES = defineFeatureFlag(
+ "redirect-legacy-dns", false,
+ "Redirect legacy DNS names to the main DNS name",
+ "Takes effect on deployment through controller",
+ APPLICATION_ID);
+
/** WARNING: public for testing: All flags should be defined in {@link Flags}. */
public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, String description,
String modificationEffect, FetchVector.Dimension... dimensions) {
diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java b/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java
index 061d803b978..466e74202c1 100644
--- a/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java
+++ b/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java
@@ -12,6 +12,7 @@ import com.yahoo.jdisc.service.CurrentContainer;
import com.yahoo.jdisc.service.ServerProvider;
import java.net.URI;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@@ -35,7 +36,7 @@ import java.util.concurrent.TimeUnit;
*/
public class Request extends AbstractResource {
- private final Map<String, Object> context = new HashMap<>();
+ private final Map<String, Object> context = Collections.synchronizedMap(new HashMap<>());
private final HeaderFields headers = new HeaderFields();
private final Container container;
private final Request parent;
@@ -205,10 +206,6 @@ public class Request extends AbstractResource {
* <p>Returns the named application context objects. This data is not intended for network transport, rather they
* are intended for passing shared data between components of an Application.</p>
*
- * <p>Modifying the context map is a thread-unsafe operation -- any changes made after calling {@link
- * #connect(ResponseHandler)} might never become visible to other threads, and might throw
- * ConcurrentModificationExceptions in other threads.</p>
- *
* @return The context map.
*/
public Map<String, Object> context() {
diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java b/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java
index 0e6e5d28260..350d8170987 100644
--- a/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java
+++ b/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java
@@ -65,7 +65,7 @@ public class UriPattern implements Comparable<UriPattern> {
if (!matcher.find()) {
throw new IllegalArgumentException(uri);
}
- scheme = GlobPattern.compile(resolvePatternComponent(matcher.group(1)));
+ scheme = GlobPattern.compile(normalizeScheme(resolvePatternComponent(matcher.group(1))));
host = GlobPattern.compile(resolvePatternComponent(matcher.group(2)));
port = resolvePortPattern(matcher.group(4));
path = GlobPattern.compile(resolvePatternComponent(matcher.group(7)));
@@ -91,7 +91,7 @@ public class UriPattern implements Comparable<UriPattern> {
return null;
}
// Match scheme before host because it has a higher chance of differing (e.g. http versus https)
- GlobPattern.Match schemeMatch = scheme.match(resolveUriComponent(uri.getScheme()));
+ GlobPattern.Match schemeMatch = scheme.match(normalizeScheme(resolveUriComponent(uri.getScheme())));
if (schemeMatch == null) {
return null;
}
@@ -172,6 +172,11 @@ public class UriPattern implements Comparable<UriPattern> {
}
}
+ private static String normalizeScheme(String scheme) {
+ if (scheme.equals("https")) return "http"; // handle 'https' in bindings and uris as 'http'
+ return scheme;
+ }
+
/**
* <p>This class holds the result of a {@link UriPattern#match(URI)} operation. It contains methods to inspect the
* groups captured during matching, where a <em>group</em> is defined as a sequence of characters matches by a
diff --git a/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java b/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java
index c91a7134c3a..d2499bbf369 100644
--- a/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java
+++ b/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java
@@ -295,6 +295,15 @@ public class UriPatternTestCase {
assertMatch(httpsPattern, "https://host/path", NO_GROUPS);
}
+ @Test
+ public void requireThatHttpsSchemeIsHandledAsHttp() {
+ UriPattern httpPattern = new UriPattern("http://host:80/path");
+ assertMatch(httpPattern, "https://host:80/path", NO_GROUPS);
+
+ UriPattern httpsPattern = new UriPattern("https://host:443/path");
+ assertMatch(httpsPattern, "http://host:443/path", NO_GROUPS);
+ }
+
private static void assertIllegalPattern(String uri) {
try {
new UriPattern(uri);
diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java
index 1e92fbef967..4239d2120cf 100644
--- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java
+++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java
@@ -16,10 +16,9 @@ import javax.servlet.AsyncListener;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
-
import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
@@ -40,19 +39,25 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G
GET, PATCH, POST, PUT, DELETE, OPTIONS, HEAD, OTHER
}
+ public enum HttpScheme {
+ HTTP, HTTPS, OTHER
+ }
+
private static final String[] HTTP_RESPONSE_GROUPS = { Metrics.RESPONSES_1XX, Metrics.RESPONSES_2XX, Metrics.RESPONSES_3XX,
Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX, Metrics.RESPONSES_401, Metrics.RESPONSES_403};
private final AtomicLong inFlight = new AtomicLong();
- private final LongAdder statistics[][];
+ private final LongAdder statistics[][][];
public HttpResponseStatisticsCollector() {
super();
- statistics = new LongAdder[HttpMethod.values().length][];
- for (int method = 0; method < statistics.length; method++) {
- statistics[method] = new LongAdder[HTTP_RESPONSE_GROUPS.length];
- for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) {
- statistics[method][group] = new LongAdder();
+ statistics = new LongAdder[HttpScheme.values().length][HttpMethod.values().length][];
+ for (int scheme = 0; scheme < HttpScheme.values().length; ++scheme) {
+ for (int method = 0; method < HttpMethod.values().length; method++) {
+ statistics[scheme][method] = new LongAdder[HTTP_RESPONSE_GROUPS.length];
+ for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) {
+ statistics[scheme][method][group] = new LongAdder();
+ }
}
}
}
@@ -110,10 +115,11 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G
private void observeEndOfRequest(Request request, HttpServletResponse flushableResponse) throws IOException {
int group = groupIndex(request);
if (group >= 0) {
+ HttpScheme scheme = getScheme(request);
HttpMethod method = getMethod(request);
- statistics[method.ordinal()][group].increment();
+ statistics[scheme.ordinal()][method.ordinal()][group].increment();
if (group == 5 || group == 6) { // if 401/403, also increment 4xx
- statistics[method.ordinal()][3].increment();
+ statistics[scheme.ordinal()][method.ordinal()][3].increment();
}
}
@@ -146,6 +152,17 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G
}
}
+ private HttpScheme getScheme(Request request) {
+ switch (request.getScheme()) {
+ case "http":
+ return HttpScheme.HTTP;
+ case "https":
+ return HttpScheme.HTTPS;
+ default:
+ return HttpScheme.OTHER;
+ }
+ }
+
private HttpMethod getMethod(Request request) {
switch (request.getMethod()) {
case "GET":
@@ -167,17 +184,18 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G
}
}
- public Map<String, Map<String, Long>> takeStatisticsByMethod() {
- Map<String, Map<String, Long>> ret = new HashMap<>();
-
- for (HttpMethod method : HttpMethod.values()) {
- int methodIndex = method.ordinal();
- Map<String, Long> methodStats = new HashMap<>();
- ret.put(method.toString(), methodStats);
-
- for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) {
- long value = statistics[methodIndex][group].sumThenReset();
- methodStats.put(HTTP_RESPONSE_GROUPS[group], value);
+ public List<StatisticsEntry> takeStatistics() {
+ var ret = new ArrayList<StatisticsEntry>();
+ for (HttpScheme scheme : HttpScheme.values()) {
+ int schemeIndex = scheme.ordinal();
+ for (HttpMethod method : HttpMethod.values()) {
+ int methodIndex = method.ordinal();
+ for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) {
+ long value = statistics[schemeIndex][methodIndex][group].sumThenReset();
+ if (value > 0) {
+ ret.add(new StatisticsEntry(scheme.name().toLowerCase(), method.name(), HTTP_RESPONSE_GROUPS[group], value));
+ }
+ }
}
}
return ret;
@@ -216,4 +234,19 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G
FutureCallback futureCallback = shutdown.get();
return futureCallback != null && futureCallback.isDone();
}
+
+ public static class StatisticsEntry {
+ public final String scheme;
+ public final String method;
+ public final String name;
+ public final long value;
+
+
+ public StatisticsEntry(String scheme, String method, String name, long value) {
+ this.scheme = scheme;
+ this.method = method;
+ this.name = name;
+ this.value = value;
+ }
+ }
}
diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java
index 6b371473a57..556d80d3772 100644
--- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java
+++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java
@@ -17,6 +17,7 @@ import java.net.SocketException;
import java.nio.channels.ServerSocketChannel;
import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -28,7 +29,7 @@ class JDiscServerConnector extends ServerConnector {
public static final String REQUEST_ATTRIBUTE = JDiscServerConnector.class.getName();
private final static Logger log = Logger.getLogger(JDiscServerConnector.class.getName());
private final Metric.Context metricCtx;
- private final Map<String, Metric.Context> requestMetricContextCache = new ConcurrentHashMap<>();
+ private final Map<RequestDimensions, Metric.Context> requestMetricContextCache = new ConcurrentHashMap<>();
private final ServerConnectionStatistics statistics;
private final boolean tcpKeepAlive;
private final boolean tcpNoDelay;
@@ -124,9 +125,12 @@ class JDiscServerConnector extends ServerConnector {
public Metric.Context getRequestMetricContext(HttpServletRequest request) {
String method = request.getMethod();
- return requestMetricContextCache.computeIfAbsent(method, ignored -> {
+ String scheme = request.getScheme();
+ var requestDimensions = new RequestDimensions(method, scheme);
+ return requestMetricContextCache.computeIfAbsent(requestDimensions, ignored -> {
Map<String, Object> dimensions = createConnectorDimensions(listenPort, connectorName);
dimensions.put(JettyHttpServer.Metrics.METHOD_DIMENSION, method);
+ dimensions.put(JettyHttpServer.Metrics.SCHEME_DIMENSION, scheme);
return metric.createContext(dimensions);
});
}
@@ -142,4 +146,27 @@ class JDiscServerConnector extends ServerConnector {
return props;
}
+ private static class RequestDimensions {
+ final String method;
+ final String scheme;
+
+ RequestDimensions(String method, String scheme) {
+ this.method = method;
+ this.scheme = scheme;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RequestDimensions that = (RequestDimensions) o;
+ return Objects.equals(method, that.method) && Objects.equals(scheme, that.scheme);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(method, scheme);
+ }
+ }
+
}
diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java
index 0dbc5f59f67..30a1b1d885c 100644
--- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java
+++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java
@@ -8,7 +8,6 @@ import com.yahoo.component.ComponentId;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.container.logging.AccessLog;
import com.yahoo.jdisc.Metric;
-import com.yahoo.jdisc.Metric.Context;
import com.yahoo.jdisc.application.OsgiFramework;
import com.yahoo.jdisc.http.ServerConfig;
import com.yahoo.jdisc.http.ServletPathsConfig;
@@ -71,6 +70,7 @@ public class JettyHttpServer extends AbstractServerProvider {
String NAME_DIMENSION = "serverName";
String PORT_DIMENSION = "serverPort";
String METHOD_DIMENSION = "httpMethod";
+ String SCHEME_DIMENSION = "scheme";
String NUM_OPEN_CONNECTIONS = "serverNumOpenConnections";
String NUM_CONNECTIONS_OPEN_MAX = "serverConnectionsOpenMax";
@@ -357,13 +357,12 @@ public class JettyHttpServer extends AbstractServerProvider {
}
private void addResponseMetrics(HttpResponseStatisticsCollector statisticsCollector) {
- Map<String, Map<String, Long>> statistics = statisticsCollector.takeStatisticsByMethod();
- statistics.forEach((httpMethod, statsByResponseType) -> {
+ for (var metricEntry : statisticsCollector.takeStatistics()) {
Map<String, Object> dimensions = new HashMap<>();
- dimensions.put(Metrics.METHOD_DIMENSION, httpMethod);
- Context ctx = metric.createContext(dimensions);
- statsByResponseType.forEach((group, value) -> metric.add(group, value, ctx));
- });
+ dimensions.put(Metrics.METHOD_DIMENSION, metricEntry.method);
+ dimensions.put(Metrics.SCHEME_DIMENSION, metricEntry.scheme);
+ metric.add(metricEntry.name, metricEntry.value, metric.createContext(dimensions));
+ }
}
private void setConnectorMetrics(JDiscServerConnector connector) {
diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java
index 3c23a2b0937..df2308f6dd0 100644
--- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java
+++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.http.server.jetty;
+import com.yahoo.jdisc.http.server.jetty.HttpResponseStatisticsCollector.StatisticsEntry;
import com.yahoo.jdisc.http.server.jetty.JettyHttpServer.Metrics;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpURI;
@@ -22,10 +23,9 @@ import org.testng.annotations.Test;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
-
import java.io.IOException;
import java.nio.ByteBuffer;
-import java.util.Map;
+import java.util.List;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
@@ -40,55 +40,62 @@ public class HttpResponseStatisticsCollectorTest {
@Test
public void statistics_are_aggregated_by_category() throws Exception {
- testRequest(300, "GET");
- testRequest(301, "GET");
- testRequest(200, "GET");
+ testRequest("http", 300, "GET");
+ testRequest("http", 301, "GET");
+ testRequest("http", 200, "GET");
- Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod();
- assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L));
- assertThat(stats.get("GET").get(Metrics.RESPONSES_3XX), equalTo(2L));
+ var stats = collector.takeStatistics();
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L);
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_3XX, 2L);
}
@Test
- public void statistics_are_grouped_by_http_method() throws Exception {
- testRequest(200, "GET");
- testRequest(200, "PUT");
- testRequest(200, "POST");
- testRequest(200, "POST");
- testRequest(404, "GET");
-
- Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod();
- assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L));
- assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(1L));
- assertThat(stats.get("PUT").get(Metrics.RESPONSES_2XX), equalTo(1L));
- assertThat(stats.get("POST").get(Metrics.RESPONSES_2XX), equalTo(2L));
+ public void statistics_are_grouped_by_http_method_and_scheme() throws Exception {
+ testRequest("http", 200, "GET");
+ testRequest("http", 200, "PUT");
+ testRequest("http", 200, "POST");
+ testRequest("http", 200, "POST");
+ testRequest("http", 404, "GET");
+ testRequest("https", 404, "GET");
+ testRequest("https", 200, "POST");
+ testRequest("https", 200, "POST");
+ testRequest("https", 200, "POST");
+ testRequest("https", 200, "POST");
+
+ var stats = collector.takeStatistics();
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L);
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_4XX, 1L);
+ assertStatisticsEntryPresent(stats, "http", "PUT", Metrics.RESPONSES_2XX, 1L);
+ assertStatisticsEntryPresent(stats, "http", "POST", Metrics.RESPONSES_2XX, 2L);
+ assertStatisticsEntryPresent(stats, "https", "GET", Metrics.RESPONSES_4XX, 1L);
+ assertStatisticsEntryPresent(stats, "https", "POST", Metrics.RESPONSES_2XX, 4L);
}
@Test
public void statistics_include_grouped_and_single_statuscodes() throws Exception {
- testRequest(401, "GET");
- testRequest(404, "GET");
- testRequest(403, "GET");
+ testRequest("http", 401, "GET");
+ testRequest("http", 404, "GET");
+ testRequest("http", 403, "GET");
- Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod();
- assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(3L));
- assertThat(stats.get("GET").get(Metrics.RESPONSES_401), equalTo(1L));
- assertThat(stats.get("GET").get(Metrics.RESPONSES_403), equalTo(1L));
+ var stats = collector.takeStatistics();
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_4XX, 3L);
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_401, 1L);
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_403, 1L);
}
@Test
public void retrieving_statistics_resets_the_counters() throws Exception {
- testRequest(200, "GET");
- testRequest(200, "GET");
+ testRequest("http", 200, "GET");
+ testRequest("http", 200, "GET");
- Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod();
- assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(2L));
+ var stats = collector.takeStatistics();
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 2L);
- testRequest(200, "GET");
+ testRequest("http", 200, "GET");
- stats = collector.takeStatisticsByMethod();
- assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L));
+ stats = collector.takeStatistics();
+ assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L);
}
@BeforeTest
@@ -116,9 +123,9 @@ public class HttpResponseStatisticsCollectorTest {
server.start();
}
- private Request testRequest(int responseCode, String httpMethod) throws Exception {
+ private Request testRequest(String scheme, int responseCode, String httpMethod) throws Exception {
HttpChannel channel = new HttpChannel(connector, new HttpConfiguration(), null, new DummyTransport());
- MetaData.Request metaData = new MetaData.Request(httpMethod, new HttpURI("http://foo/bar"), HttpVersion.HTTP_1_1, new HttpFields());
+ MetaData.Request metaData = new MetaData.Request(httpMethod, new HttpURI(scheme + "://foo/bar"), HttpVersion.HTTP_1_1, new HttpFields());
Request req = channel.getRequest();
req.setMetaData(metaData);
@@ -127,6 +134,15 @@ public class HttpResponseStatisticsCollectorTest {
return req;
}
+ private static void assertStatisticsEntryPresent(List<StatisticsEntry> result, String scheme, String method, String name, long expectedValue) {
+ long value = result.stream()
+ .filter(entry -> entry.method.equals(method) && entry.scheme.equals(scheme) && entry.name.equals(name))
+ .mapToLong(entry -> entry.value)
+ .findAny()
+ .orElseThrow(() -> new AssertionError(String.format("Not matching entry in result (scheme=%s, method=%s, name=%s)", scheme, method, name)));
+ assertThat(value, equalTo(expectedValue));
+ }
+
private final class DummyTransport implements HttpTransport {
@Override
public void send(Response info, boolean head, ByteBuffer content, boolean lastContent, Callback callback) {
diff --git a/jrt/pom.xml b/jrt/pom.xml
index 5208c0417cc..e9383654e30 100644
--- a/jrt/pom.xml
+++ b/jrt/pom.xml
@@ -34,6 +34,16 @@
<artifactId>security-utils</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
+ <exclusions>
+ <exclusion> <!-- not needed -->
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ </exclusion>
+ <exclusion> <!-- not needed -->
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpcore</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency> <!-- required due to bug in maven dependency resolving - bouncycastle is compile scope in security-utils, yet it is not part of test scope here -->
<groupId>org.bouncycastle</groupId>
diff --git a/logd/src/logd/legacy_forwarder.cpp b/logd/src/logd/legacy_forwarder.cpp
index c0f74d205e7..851e4458f77 100644
--- a/logd/src/logd/legacy_forwarder.cpp
+++ b/logd/src/logd/legacy_forwarder.cpp
@@ -11,6 +11,7 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <fcntl.h>
#include <unistd.h>
+#include <sstream>
#include <vespa/log/log.h>
LOG_SETUP("");
@@ -126,12 +127,12 @@ void
LegacyForwarder::forwardLine(std::string_view line)
{
assert(_logserver_fd >= 0);
- assert (line.size() > 0);
assert (line.size() < 1024*1024);
- assert (line[line.size() - 1] == '\n');
if (parseLine(line)) {
- forwardText(line.data(), line.size());
+ std::ostringstream line_copy;
+ line_copy << line << std::endl;
+ forwardText(line_copy.str().data(), line_copy.str().size());
}
}
diff --git a/logd/src/logd/watcher.cpp b/logd/src/logd/watcher.cpp
index a92ad456e9f..fca9cd648bb 100644
--- a/logd/src/logd/watcher.cpp
+++ b/logd/src/logd/watcher.cpp
@@ -222,7 +222,7 @@ Watcher::watchfile()
}
while (nnl != nullptr && elapsed(tickStart) < 1) {
++nnl;
- _forwarder.forwardLine(std::string_view(l, (nnl - l)));
+ _forwarder.forwardLine(std::string_view(l, (nnl - l) - 1));
ssize_t wsize = nnl - l;
offset += wsize;
l = nnl;
diff --git a/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp b/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp
index d3339894819..67d47a49384 100644
--- a/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp
+++ b/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp
@@ -40,7 +40,7 @@ struct ForwardFixture {
timer.SetNow();
std::stringstream ss;
ss << std::fixed << timer.Secs();
- ss << "\texample.yahoo.com\t7518/34779\tlogd\tlogdemon\tevent\tstarted/1 name=\"logdemon\"\n";
+ ss << "\texample.yahoo.com\t7518/34779\tlogd\tlogdemon\tevent\tstarted/1 name=\"logdemon\"";
return ss.str();
}
@@ -50,7 +50,7 @@ struct ForwardFixture {
int rfd = open(fname.c_str(), O_RDONLY);
char *buffer[2048];
ssize_t bytes = read(rfd, buffer, 2048);
- ssize_t expected = doForward ? logLine.length() : 0;
+ ssize_t expected = doForward ? logLine.length() + 1 : 0;
EXPECT_EQUAL(expected, bytes);
close(rfd);
}
diff --git a/logd/src/tests/watcher/watcher_test.cpp b/logd/src/tests/watcher/watcher_test.cpp
index c2b379cc1a4..fffaac17058 100644
--- a/logd/src/tests/watcher/watcher_test.cpp
+++ b/logd/src/tests/watcher/watcher_test.cpp
@@ -71,8 +71,7 @@ struct DummyForwarder : public Forwarder {
void sendMode() override { ++sendModeCount; }
void forwardLine(std::string_view log_line) override {
std::lock_guard guard(lock);
- assert(log_line.size() > 0u);
- lines.emplace_back(log_line.substr(0, log_line.size() - 1));
+ lines.emplace_back(log_line);
cond.notify_all();
}
void flush() override { }
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index c4acfeb3235..9c8f6238731 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -29,9 +29,17 @@ public class OrderedTensorType {
private final long[] innerSizesVespa;
private final int[] dimensionMap;
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
+ this.type = new TensorType.Builder(valueType, dimensions).build();
+ this.innerSizesOriginal = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ private OrderedTensorType(TensorType type) {
+ this.dimensions = type.dimensions();
+ this.type = type;
this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
@@ -136,11 +144,11 @@ public class OrderedTensorType {
renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
}
}
- return new OrderedTensorType(renamedDimensions);
+ return new OrderedTensorType(type.valueType(), renamedDimensions);
}
public OrderedTensorType rename(String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType());
for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
Optional<Long> dimSize = dimensions.get(i).size();
@@ -154,7 +162,7 @@ public class OrderedTensorType {
}
public static OrderedTensorType standardType(OrderedTensorType type) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType());
for (int i = 0; i < type.dimensions().size(); ++ i) {
TensorType.Dimension dim = type.dimensions().get(i);
String dimensionName = "d" + i;
@@ -193,18 +201,18 @@ public class OrderedTensorType {
* where dimensions are listed in the order of this rather than the natural order of their names.
*/
public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ return new OrderedTensorType(TensorType.fromSpec(typeSpec));
}
- public static OrderedTensorType fromDimensionList(List<Long> dims) {
- return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) {
+ return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ...
}
- private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dims.size(); ++ i) {
+ private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType);
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
+ Long dimSize = dimensions.get(i);
if (dimSize >= 0) {
builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
@@ -216,9 +224,15 @@ public class OrderedTensorType {
public static class Builder {
+ private final TensorType.Value valueType;
private final List<TensorType.Dimension> dimensions;
public Builder() {
+ this(TensorType.Value.DOUBLE);
+ }
+
+ public Builder(TensorType.Value valueType) {
+ this.valueType = valueType;
this.dimensions = new ArrayList<>();
}
@@ -228,7 +242,7 @@ public class OrderedTensorType {
}
public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
+ return new OrderedTensorType(valueType, dimensions);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index dd2add973e4..a469e666d93 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -105,8 +105,8 @@ class GraphImporter {
if (isArgumentTensor(name, onnxGraph)) {
Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
if (valueInfoProto == null)
- throw new IllegalArgumentException("Could not find argument tensor: " + name);
- OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
+ throw new IllegalArgumentException("Could not find argument tensor '" + name + "'");
+ OrderedTensorType type = TypeConverter.typeFrom(valueInfoProto.getType());
operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
intermediateGraph.inputs(intermediateGraph.defaultSignature())
@@ -114,7 +114,7 @@ class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index f251a14213b..29d600fa7c6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -30,13 +30,10 @@ class TypeConverter {
}
}
- static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ static OrderedTensorType typeFrom(Onnx.TypeProto type) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
@@ -49,4 +46,28 @@ class TypeConverter {
return builder.build();
}
+ static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
+ return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()),
+ tensor.getDimsList());
+ }
+
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 1a564661ccb..7ae50a0549d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return null;
- }
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
- if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a constant.");
- }
+ if ( ! concatDimOp.getConstantValue().isPresent())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant.");
+
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
- if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a scalar.");
- }
+ if (concatDimTensor.type().rank() != 0)
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar.");
OrderedTensorType aType = inputs.get(0).type().get();
concatDimensionIndex = (int)concatDimTensor.asDouble();
@@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
- if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "inputs must have save rank.");
- }
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank.");
+
for (int j = 0; j < aType.rank(); ++j) {
long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
@@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation {
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDimensionIndex) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 8ae6d81b8d4..c64b9ded601 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis must be a constant.");
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis argument must be a scalar.");
- }
+ if (axis.type().rank() != 0)
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar.");
OrderedTensorType inputType = inputs.get(0).type().get();
int dimensionToInsert = (int)axis.asDouble();
@@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
@@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(2)) return null;
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : expandDimensions) {
typeBuilder.indexed(name, 1);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 3b77f9527ca..0ee54f839bc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -17,6 +18,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
+import java.util.stream.Collectors;
/**
* Wraps an imported operation node and produces the respective Vespa tensor
@@ -161,6 +163,19 @@ public abstract class IntermediateOperation {
}
/**
+ * Returns the largest value type among the input value types.
+ * This should only be called after it has been verified that input types are available.
+ *
+ * @throws IllegalArgumentException if a type cannot be uniquely determined
+ * @throws RuntimeException if called when input types are not available
+ */
+ TensorType.Value resultValueType() {
+ return TensorType.Value.largestOf(inputs.stream()
+ .map(input -> input.type().get().type().valueType())
+ .collect(Collectors.toList()));
+ }
+
+ /**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
*/
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index fed95e13bb7..c2d75153586 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -22,13 +22,12 @@ public class Join extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
int sizeDifference = a.rank() - b.rank();
for (int i = 0; i < a.rank(); ++i) {
TensorType.Dimension aDim = a.dimensions().get(i);
@@ -52,12 +51,8 @@ public class Join extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -92,9 +87,8 @@ public class Join extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
int sizeDifference = a.rank() - b.rank();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 1dbfd6e40dc..9a76662529d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ if ( ! allInputTypesPresent(2)) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
return typeBuilder.build();
@@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType aType = inputs.get(0).type().get();
OrderedTensorType bType = inputs.get(1).type().get();
if (aType.type().rank() < 2 || bType.type().rank() < 2)
@@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
@@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index 4be220db9d5..d8e9950c61f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation reductionIndices = inputs.get(1);
- if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + name + ": " +
- "reduction indices must be a constant.");
+ if ( ! reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
reduceDimensions = new ArrayList<>();
@@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
+
TensorFunction inputFunction = inputs.get(0).function().get();
TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : reduceDimensions) {
typeBuilder.indexed(name, 1);
}
@@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
- if (!reduceDimensions.contains(dimension.name())) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
} else if (keepDimensions) {
builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 18f3cc1cc39..4a0fe236c9f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation newShape = inputs.get(1);
- if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + name + ": " +
- "shape input must be a constant.");
- }
+ if ( ! newShape.getConstantValue().isPresent())
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
return reshape(inputFunction, inputType.type(), type.type());
@@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
+ if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
@@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation {
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+ new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
- return outputFunction;
+ return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
}
private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
+ if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);
- }
+
List<ExpressionNode> children = new ArrayList<>();
List<ArithmeticOperator> operators = new ArrayList<>();
int size = 1;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 361729a8c14..79f3012c327 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder()
+ return new OrderedTensorType.Builder(resultValueType())
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 2eeefcbe8a2..52d40144f61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
@@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(1)) return null;
+
TensorFunction inputFunction = inputs.get(0).function().get();
return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
}
@@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index cb838cd67b1..a07c0fdf4dc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -51,7 +51,7 @@ class GraphImporter {
String nodeName = node.getName();
String modelName = graph.name();
int nodePort = IntermediateOperation.indexPartOf(nodeName);
- OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
+ OrderedTensorType nodeType = TypeConverter.typeFrom(node);
AttributeConverter attributes = AttributeConverter.convert(node);
switch (node.getOp().toLowerCase()) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 6c92ffa6055..9cba388d00e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.tensorflow.DataType;
import org.tensorflow.framework.TensorProto;
import java.nio.ByteBuffer;
@@ -27,7 +28,7 @@ public class TensorConverter {
}
private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
@@ -53,16 +54,6 @@ public class TensorConverter {
return builder.build();
}
- private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder();
- int dimensionIndex = 0;
- for (long dimensionSize : shape) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
public static Long tensorSize(TensorType type) {
Long size = 1L;
for (TensorType.Dimension dimension : type.dimensions()) {
@@ -85,7 +76,7 @@ public class TensorConverter {
case INT64: return new LongValues(tfTensor);
}
throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
+ tfTensor.dataType() + " to a Vespa tensor");
}
private static Values readValuesOf(TensorProto tensorProto) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 63a605ce97a..d8ddb01b650 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -5,11 +5,10 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.DataType;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
-import java.util.List;
-
/**
* Converts and verifies TensorFlow tensor types into Vespa tensor types.
*
@@ -22,7 +21,7 @@ class TypeConverter {
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
+ "does not match Vespa shape");
}
for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
int vespaIndex = type.dimensionMap(tensorFlowIndex);
@@ -30,33 +29,16 @@ class TypeConverter {
TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
+ "does not match Vespa dimensions");
}
}
}
}
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ static OrderedTensorType typeFrom(NodeDef node) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
TensorShapeProto shape = tensorFlowShape(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
@@ -69,4 +51,71 @@ class TypeConverter {
return builder.build();
}
+ static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
+ int dimensionIndex = 0;
+ for (long dimensionSize : tfTensor.shape()) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
+ }
+ return b.build();
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null)
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+
+ return attrValueList.getList().getShape(0); // support multiple outputs?
+ }
+
+ private static DataType tensorFlowValueType(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("dtypes");
+ if (attrValueList == null)
+ return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better?
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
+ return DataType.DT_DOUBLE; // default
+
+ return attrValueList.getList().getType(0); // support multiple outputs?
+ }
+
+ private static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case DT_FLOAT: return TensorType.Value.FLOAT;
+ case DT_DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case DT_BOOL: return TensorType.Value.FLOAT;
+ case DT_BFLOAT16: return TensorType.Value.FLOAT;
+ case DT_HALF: return TensorType.Value.FLOAT;
+ case DT_INT8: return TensorType.Value.FLOAT;
+ case DT_INT16: return TensorType.Value.FLOAT;
+ case DT_INT32: return TensorType.Value.FLOAT;
+ case DT_INT64: return TensorType.Value.DOUBLE;
+ case DT_UINT8: return TensorType.Value.FLOAT;
+ case DT_UINT16: return TensorType.Value.FLOAT;
+ case DT_UINT32: return TensorType.Value.FLOAT;
+ case DT_UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
+ private static TensorType.Value toValueType(org.tensorflow.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
index afe699d6e05..61f332327be 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
@@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase {
@Test
public void testToFromSpec() {
String spec = "tensor(b[],c{},a[3])";
+ String orderedSpec = "tensor(a[3],b[],c{})";
OrderedTensorType type = OrderedTensorType.fromSpec(spec);
- assertEquals(spec, type.toString());
- assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ assertEquals(orderedSpec, type.toString());
+ assertEquals(orderedSpec, type.type().toString());
}
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 424e4d6c57c..07814687dc6 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -43,14 +43,14 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check inputs
assertEquals(1, model.inputs().size());
assertTrue(model.inputs().containsKey("Placeholder"));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add");
assertNotNull(output);
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.expression());
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 79c633b9617..b8c51f4e33d 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -886,6 +886,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorTypeArgument()",
+ "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()",
"public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)",
"public final java.lang.String tensorFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 2f173ad0266..c83de4ced0a 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -598,9 +598,12 @@ Reduce.Aggregator tensorReduceAggregator() :
TensorType tensorTypeArgument() :
{
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder;
+ TensorType.Value valueType;
}
{
+ valueType = optionalTensorValueTypeParameter()
+ { builder = new TensorType.Builder(valueType); }
<LBRACE>
( tensorTypeDimension(builder) ) ?
( <COMMA> tensorTypeDimension(builder) ) *
@@ -608,6 +611,15 @@ TensorType tensorTypeArgument() :
{ return builder.build(); }
}
+TensorType.Value optionalTensorValueTypeParameter() :
+{
+ String valueType = "double";
+}
+{
+ ( <LT> valueType = identifier() <GT> )?
+ { return TensorTypeParser.toValueType(valueType); }
+}
+
// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
void tensorTypeDimension(TensorType.Builder builder) :
{
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index f2122bb5da9..f7e38862883 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -238,6 +238,8 @@ public class EvaluationTestCase {
"{{x:0}:1}", "{}", "{{y:0,z:0}:1}");
tester.assertEvaluates("tensor(x{}):{}",
"tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }");
+ tester.assertEvaluates("tensor<float>(x{}):{}",
+ "tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }");
tester.assertEvaluates("{ {x:0}:15 }",
"tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:15 }",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ba0db4de5e1..488930a8eb9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -40,7 +40,7 @@ public class EvaluationTester {
int argumentIndex = 0;
for (String argumentString : tensorArgumentStrings) {
Tensor argument;
- if (argumentString.startsWith("tensor(")) // explicitly decided type
+ if (argumentString.startsWith("tensor")) // explicitly decided type
argument = Tensor.from(argumentString);
else // use mappedTensors+dimensions in tensor to decide type
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp
index 66d2e48194d..1b85fb8f427 100644
--- a/searchlib/src/tests/aggregator/perdocexpr.cpp
+++ b/searchlib/src/tests/aggregator/perdocexpr.cpp
@@ -1325,6 +1325,75 @@ TEST("testAggregationResults") {
FloatResultNode(15.54));
}
+TEST("test Average over integer") {
+ AggregationResult::Configure conf;
+ AverageAggregationResult avg;
+ avg.setExpression(createScalarInt(I4)).select(conf, conf);
+ avg.aggregate(0, 0);
+ EXPECT_EQUAL(I4, avg.getAverage().getInteger());
+}
+
+TEST("test Average over float") {
+ AggregationResult::Configure conf;
+ AverageAggregationResult avg;
+ avg.setExpression(createScalarFloat(I4)).select(conf, conf);
+ avg.aggregate(0, 0);
+ EXPECT_EQUAL(I4, avg.getAverage().getInteger());
+}
+
+TEST("test Average over numeric string") {
+ AggregationResult::Configure conf;
+ AverageAggregationResult avg;
+ avg.setExpression(createScalarString("7.8")).select(conf, conf);
+ avg.aggregate(0, 0);
+ EXPECT_EQUAL(7.8, avg.getAverage().getFloat());
+}
+
+TEST("test Average over non-numeric string") {
+ AggregationResult::Configure conf;
+ AverageAggregationResult avg;
+ avg.setExpression(createScalarString("ABC")).select(conf, conf);
+ avg.aggregate(0, 0);
+ EXPECT_EQUAL(0, avg.getAverage().getInteger());
+}
+
+TEST("test Sum over integer") {
+ AggregationResult::Configure conf;
+ SumAggregationResult sum;
+ sum.setExpression(createScalarInt(I4)).select(conf, conf);
+ sum.aggregate(0, 0);
+ sum.aggregate(0, 0);
+ EXPECT_EQUAL(I4*2, sum.getSum().getInteger());
+}
+
+TEST("test Sum over float") {
+ AggregationResult::Configure conf;
+ SumAggregationResult sum;
+ sum.setExpression(createScalarFloat(I4)).select(conf, conf);
+ sum.aggregate(0, 0);
+ sum.aggregate(0, 0);
+ EXPECT_EQUAL(I4*2, sum.getSum().getInteger());
+}
+
+TEST("test Sum over numeric string") {
+ AggregationResult::Configure conf;
+ SumAggregationResult sum;
+ sum.setExpression(createScalarString("7.8")).select(conf, conf);
+ sum.aggregate(0, 0);
+ sum.aggregate(0, 0);
+ EXPECT_EQUAL(7.8*2, sum.getSum().getFloat());
+}
+
+TEST("test Sum over non-numeric string") {
+ AggregationResult::Configure conf;
+ SumAggregationResult sum;
+ sum.setExpression(createScalarString("ABC")).select(conf, conf);
+ sum.aggregate(0, 0);
+ sum.aggregate(0, 0);
+ EXPECT_EQUAL(0, sum.getSum().getInteger());
+}
+
+
TEST("testGrouping") {
AttributeGuard attr1 = createInt64Attribute();
ExpressionNode::UP result1(new CountAggregationResult());
diff --git a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
index 42ce9725f91..54c77fb25a7 100644
--- a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
+++ b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
@@ -6,6 +6,10 @@
#include <vespa/searchlib/fef/test/ftlib.h>
#include <vespa/searchlib/fef/test/rankresult.h>
#include <vespa/searchlib/fef/test/dummy_dependency_handler.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
using namespace search;
using namespace search::attribute;
@@ -104,7 +108,26 @@ struct ArrayFixture : FixtureBase {
}
template <typename ExpectedType>
- void check_prepare_state_output(const vespalib::string& input_vector) {
+ void check_prepare_state_output(const vespalib::tensor::Tensor & tensor, vespalib::tensor::SerializeFormat format, const ExpectedType & expected) {
+ vespalib::nbostream os;
+ vespalib::tensor::TypedBinaryFormat::serialize(os, tensor, format);
+ vespalib::string input_vector(os.c_str(), os.size());
+ check_prepare_state_output(".tensor", input_vector, expected);
+ }
+
+ template <typename ExpectedType>
+ void check_prepare_state_output(const vespalib::string& input_vector, const ExpectedType & expected) {
+ check_prepare_state_output("", input_vector, expected);
+ }
+ template <typename T>
+ static void verify(const dotproduct::ArrayParam<T> & a, const dotproduct::ArrayParam<T> & b) {
+ ASSERT_EQUAL(a.values.size(), b.values.size());
+ for (size_t i(0); i < a.values.size(); i++) {
+ EXPECT_EQUAL(a.values[i], b.values[i]);
+ }
+ }
+ template <typename ExpectedType>
+ void check_prepare_state_output(const vespalib::string & postfix, const vespalib::string& input_vector, const ExpectedType & expected) {
FtFeatureTest feature(_factory, "");
DotProductBlueprint bp;
DummyDependencyHandler dependency_handler(bp);
@@ -116,7 +139,7 @@ struct ArrayFixture : FixtureBase {
FieldType::ATTRIBUTE, schema::CollectionType::ARRAY, imported_attr->getName());
bp.setup(feature.getIndexEnv(), params);
- feature.getQueryEnv().getProperties().add("dotProduct.fancyvector", input_vector);
+ feature.getQueryEnv().getProperties().add("dotProduct.fancyvector" + postfix, input_vector);
auto& obj_store = feature.getQueryEnv().getObjectStore();
bp.prepareSharedState(feature.getQueryEnv(), obj_store);
// Resulting name is very implementation defined. But at least the tests will break if it changes.
@@ -124,13 +147,12 @@ struct ArrayFixture : FixtureBase {
ASSERT_TRUE(parsed != nullptr);
const auto* as_object = dynamic_cast<const ExpectedType*>(parsed);
ASSERT_TRUE(as_object != nullptr);
- // We don't test the parsed output values here; that's the responsibility of other tests.
+ verify(expected, *as_object);
}
- void check_all_float_executions(feature_t expected,
- const vespalib::string& vector,
- DocId doc_id,
- const vespalib::string& shared_param = "") {
+ void check_all_float_executions(feature_t expected, const vespalib::string& vector,
+ DocId doc_id, const vespalib::string& shared_param = "")
+ {
check_executions<double>([this](auto float_type){ this->setup_float_mappings(float_type); },
{{BasicType::FLOAT, BasicType::DOUBLE}},
expected, vector, doc_id, shared_param);
@@ -155,22 +177,46 @@ TEST_F("Zero-length float/double array query vector evaluates to zero", ArrayFix
TEST_F("prepareSharedState emits i64 vector for i32 imported attribute", ArrayFixture) {
f.setup_integer_mappings(BasicType::INT32);
- f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]");
+ f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
}
TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFixture) {
f.setup_integer_mappings(BasicType::INT64);
- f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]");
+ f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
}
TEST_F("prepareSharedState emits double vector for float imported attribute", ArrayFixture) {
f.setup_float_mappings(BasicType::FLOAT);
- f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]");
+ f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
}
TEST_F("prepareSharedState emits double vector for double imported attribute", ArrayFixture) {
f.setup_float_mappings(BasicType::DOUBLE);
- f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]");
+ f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as float from tensor for double imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::DOUBLE);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as double from tensor for double imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::DOUBLE);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as float from tensor for float imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::FLOAT);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as double from tensor for float imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::FLOAT);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3}));
}
TEST_F("Dense i32/i64 array dot product can be evaluated with pre-parsed object parameter", ArrayFixture) {
diff --git a/searchlib/src/tests/grouping/grouping_test.cpp b/searchlib/src/tests/grouping/grouping_test.cpp
index fea18619ef9..0750d30f60d 100644
--- a/searchlib/src/tests/grouping/grouping_test.cpp
+++ b/searchlib/src/tests/grouping/grouping_test.cpp
@@ -313,10 +313,9 @@ Test::testAggregationSimple()
ctx.add(FloatAttrBuilder("float").add(3).add(7).add(15).sp());
ctx.add(StringAttrBuilder("string").add("3").add("7").add("15").sp());
- char strsum[3] = {-101, '5', 0};
- testAggregationSimpleSum(ctx, SumAggregationResult(), Int64ResultNode(25), FloatResultNode(25), StringResultNode(strsum));
- testAggregationSimpleSum(ctx, MinAggregationResult(), Int64ResultNode(3), FloatResultNode(3), StringResultNode("15"));
- testAggregationSimpleSum(ctx, MaxAggregationResult(), Int64ResultNode(15), FloatResultNode(15), StringResultNode("7"));
+ TEST_DO(testAggregationSimpleSum(ctx, SumAggregationResult(), Int64ResultNode(25), FloatResultNode(25), StringResultNode("25")));
+ TEST_DO(testAggregationSimpleSum(ctx, MinAggregationResult(), Int64ResultNode(3), FloatResultNode(3), StringResultNode("15")));
+ TEST_DO(testAggregationSimpleSum(ctx, MaxAggregationResult(), Int64ResultNode(15), FloatResultNode(15), StringResultNode("7")));
}
#define MU std::make_unique
@@ -630,6 +629,14 @@ createAggr(SingleResultNode::UP r, ExpressionNode::UP e) {
return aggr;
}
+template<typename T>
+ExpressionNode::UP
+createNumAggr(NumericResultNode::UP r, ExpressionNode::UP e) {
+ std::unique_ptr<T> aggr = MU<T>(std::move(r));
+ aggr->setExpression(std::move(e));
+ return aggr;
+}
+
void
Test::testAggregationGroupCapping()
{
@@ -680,13 +687,13 @@ Test::testAggregationGroupCapping()
Group expect;
expect.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false))
.addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false))
.addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false));
EXPECT_TRUE(testAggregation(ctx, request, expect));
@@ -701,13 +708,13 @@ Test::testAggregationGroupCapping()
Group expect = Group()
.addChild(Group().setId(Int64ResultNode(1)).setRank(RawRank(1))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true))
.addChild(Group().setId(Int64ResultNode(2)).setRank(RawRank(2))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true))
.addChild(Group().setId(Int64ResultNode(3)).setRank(RawRank(3))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true));
EXPECT_TRUE(testAggregation(ctx, request, expect));
@@ -726,13 +733,13 @@ Test::testAggregationGroupCapping()
Group expect;
expect.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(10)), false))
.addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(11)), false))
.addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(12)), false));
EXPECT_TRUE(testAggregation(ctx, request, expect));
diff --git a/searchlib/src/tests/groupingengine/groupingengine_test.cpp b/searchlib/src/tests/groupingengine/groupingengine_test.cpp
index 3920667c1d6..da9b8d62305 100644
--- a/searchlib/src/tests/groupingengine/groupingengine_test.cpp
+++ b/searchlib/src/tests/groupingengine/groupingengine_test.cpp
@@ -615,6 +615,14 @@ createAggr(SingleResultNode::UP r, ExpressionNode::UP e) {
return aggr;
}
+template<typename T>
+ExpressionNode::UP
+createNumAggr(NumericResultNode::UP r, ExpressionNode::UP e) {
+ std::unique_ptr<T> aggr = MU<T>(std::move(r));
+ aggr->setExpression(std::move(e));
+ return aggr;
+}
+
void
Test::testAggregationGroupCapping()
{
@@ -670,13 +678,13 @@ Test::testAggregationGroupCapping()
Group expect;
expect.setId(NullResultNode())
.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false))
.addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false))
.addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), false));
EXPECT_TRUE(testAggregation(ctx, request, expect));
@@ -693,13 +701,13 @@ Test::testAggregationGroupCapping()
Group expect;
expect.setId(NullResultNode())
.addChild(Group().setId(Int64ResultNode(1)).setRank(RawRank(1))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true))
.addChild(Group().setId(Int64ResultNode(2)).setRank(RawRank(2))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true))
.addChild(Group().setId(Int64ResultNode(3)).setRank(RawRank(3))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr")))
.addOrderBy(MU<AggregationRefNode>(0), true));
EXPECT_TRUE(testAggregation(ctx, request, expect));
@@ -718,13 +726,13 @@ Test::testAggregationGroupCapping()
Group expect = Group()
.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(10)), false))
.addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(11)), false))
.addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9))
- .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
+ .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr")))
.addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(12)), false));
EXPECT_TRUE(testAggregation(ctx, request, expect));
diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
index d40cdd5f13e..d6eb3a033a2 100644
--- a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
+++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
@@ -17,6 +17,17 @@ bool isReady(const ResultNode *myRes, const ResultNode &ref) {
return (myRes != 0 && myRes->getClass().id() == ref.getClass().id());
}
+template<typename Wanted, typename Fallback>
+std::unique_ptr<Wanted>
+createAndEnsureWanted(const ResultNode & result) {
+ std::unique_ptr<ResultNode> tmp = result.createBaseType();
+ if (dynamic_cast<Wanted *>(tmp.get()) != nullptr) {
+ return std::unique_ptr<Wanted>(static_cast<Wanted *>(tmp.release()));
+ } else {
+ return std::make_unique<Fallback>();
+ }
+}
+
} // namespace search::aggregation::<unnamed>
@@ -38,14 +49,14 @@ IMPLEMENT_AGGREGATIONRESULT(ExpressionCountAggregationResult, AggregationResult)
IMPLEMENT_AGGREGATIONRESULT(StandardDeviationAggregationResult, AggregationResult);
AggregationResult::AggregationResult() :
- _expressionTree(new ExpressionTree()),
+ _expressionTree(std::make_shared<ExpressionTree>()),
_tag(-1)
{ }
AggregationResult::AggregationResult(const AggregationResult &) = default;
AggregationResult & AggregationResult::operator = (const AggregationResult &) = default;
-AggregationResult::~AggregationResult() { }
+AggregationResult::~AggregationResult() = default;
void
AggregationResult::aggregate(const document::Document & doc, HitRank rank) {
@@ -66,14 +77,16 @@ AggregationResult::aggregate(DocId docId, HitRank rank) {
}
}
-bool AggregationResult::Configure::check(const vespalib::Identifiable &obj) const
+bool
+AggregationResult::Configure::check(const vespalib::Identifiable &obj) const
{
return obj.inherits(AggregationResult::classId);
}
-void AggregationResult::Configure::execute(vespalib::Identifiable &obj)
+void
+AggregationResult::Configure::execute(vespalib::Identifiable &obj)
{
- AggregationResult & a(static_cast<AggregationResult &>(obj));
+ auto & a(static_cast<AggregationResult &>(obj));
a.prepare();
}
@@ -85,37 +98,40 @@ AggregationResult::setExpression(ExpressionNode::UP expr)
return *this;
}
-void CountAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+CountAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
(void) result;
(void) useForInit;
}
-void SumAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+SumAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
if (isReady(_sum.get(), result)) {
return;
}
- _sum.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release()));
+ _sum = createAndEnsureWanted<NumericResultNode, FloatResultNode>(result);
if ( useForInit ) {
_sum->set(result);
}
}
-MinAggregationResult::MinAggregationResult() : AggregationResult() { }
+MinAggregationResult::MinAggregationResult() = default;
MinAggregationResult::MinAggregationResult(const ResultNode::CP &result)
: AggregationResult()
{
setResult(result);
}
-MinAggregationResult::~MinAggregationResult() { }
+MinAggregationResult::~MinAggregationResult() = default;
-void MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
if (isReady(_min.get(), result)) {
return;
}
- _min.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release()));
+ _min = createAndEnsureWanted<SingleResultNode, FloatResultNode>(result);
if ( !useForInit ) {
_min->setMax();
} else {
@@ -123,19 +139,20 @@ void MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
}
}
-MaxAggregationResult::MaxAggregationResult() : AggregationResult(), _max() { }
+MaxAggregationResult::MaxAggregationResult() = default;
MaxAggregationResult::MaxAggregationResult(const SingleResultNode & max)
: AggregationResult(),
_max(max)
{ }
-MaxAggregationResult::~MaxAggregationResult() { }
+MaxAggregationResult::~MaxAggregationResult() = default;
-void MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
if (isReady(_max.get(), result)) {
return;
}
- _max.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release()));
+ _max = createAndEnsureWanted<SingleResultNode, FloatResultNode>(result);
if ( !useForInit ) {
_max->setMin(); ///Should figure out how to set min too for float.
} else {
@@ -143,29 +160,33 @@ void MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
}
}
-void AverageAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+AverageAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
if (isReady(_sum.get(), result)) {
return;
}
- _sum.reset(dynamic_cast<NumericResultNode *>(result.createBaseType().release()));
+ _sum = createAndEnsureWanted<NumericResultNode, FloatResultNode>(result);
if ( useForInit ) {
_sum->set(result);
}
}
-void XorAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
+void
+XorAggregationResult::onPrepare(const ResultNode & result, bool useForInit)
{
(void) result;
(void) useForInit;
}
-void SumAggregationResult::onMerge(const AggregationResult & b)
+void
+SumAggregationResult::onMerge(const AggregationResult & b)
{
_sum->add(*static_cast<const SumAggregationResult &>(b)._sum);
}
-void SumAggregationResult::onAggregate(const ResultNode & result)
+void
+SumAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
static_cast<const ResultNodeVector &>(result).flattenSum(*_sum);
@@ -174,17 +195,20 @@ void SumAggregationResult::onAggregate(const ResultNode & result)
}
}
-void SumAggregationResult::onReset()
+void
+SumAggregationResult::onReset()
{
- _sum.reset(static_cast<SingleResultNode *>(_sum->getClass().create()));
+ _sum.reset(static_cast<NumericResultNode *>(_sum->getClass().create()));
}
-void CountAggregationResult::onMerge(const AggregationResult & b)
+void
+CountAggregationResult::onMerge(const AggregationResult & b)
{
_count.add(static_cast<const CountAggregationResult &>(b)._count);
}
-void CountAggregationResult::onAggregate(const ResultNode & result)
+void
+CountAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
_count += static_cast<const ResultNodeVector &>(result).size();
@@ -193,17 +217,20 @@ void CountAggregationResult::onAggregate(const ResultNode & result)
}
}
-void CountAggregationResult::onReset()
+void
+CountAggregationResult::onReset()
{
setCount(0);
}
-void MaxAggregationResult::onMerge(const AggregationResult & b)
+void
+MaxAggregationResult::onMerge(const AggregationResult & b)
{
_max->max(*static_cast<const MaxAggregationResult &>(b)._max);
}
-void MaxAggregationResult::onAggregate(const ResultNode & result)
+void
+MaxAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
static_cast<const ResultNodeVector &>(result).flattenMax(*_max);
@@ -212,18 +239,21 @@ void MaxAggregationResult::onAggregate(const ResultNode & result)
}
}
-void MaxAggregationResult::onReset()
+void
+MaxAggregationResult::onReset()
{
_max.reset(static_cast<SingleResultNode *>(_max->getClass().create()));
_max->setMin();
}
-void MinAggregationResult::onMerge(const AggregationResult & b)
+void
+MinAggregationResult::onMerge(const AggregationResult & b)
{
_min->min(*static_cast<const MinAggregationResult &>(b)._min);
}
-void MinAggregationResult::onAggregate(const ResultNode & result)
+void
+MinAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
static_cast<const ResultNodeVector &>(result).flattenMin(*_min);
@@ -232,22 +262,25 @@ void MinAggregationResult::onAggregate(const ResultNode & result)
}
}
-void MinAggregationResult::onReset()
+void
+MinAggregationResult::onReset()
{
_min.reset(static_cast<SingleResultNode *>(_min->getClass().create()));
_min->setMax();
}
-AverageAggregationResult::~AverageAggregationResult() {}
+AverageAggregationResult::~AverageAggregationResult() = default;
-void AverageAggregationResult::onMerge(const AggregationResult & b)
+void
+AverageAggregationResult::onMerge(const AggregationResult & b)
{
- const AverageAggregationResult & avg(static_cast<const AverageAggregationResult &>(b));
+ const auto & avg(static_cast<const AverageAggregationResult &>(b));
_sum->add(*avg._sum);
_count += avg._count;
}
-void AverageAggregationResult::onAggregate(const ResultNode & result)
+void
+AverageAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
static_cast<const ResultNodeVector &>(result).flattenSum(*_sum);
@@ -258,13 +291,15 @@ void AverageAggregationResult::onAggregate(const ResultNode & result)
}
}
-void AverageAggregationResult::onReset()
+void
+AverageAggregationResult::onReset()
{
_count = 0;
_sum.reset(static_cast<NumericResultNode *>(_sum->getClass().create()));
}
-const NumericResultNode & AverageAggregationResult::getAverage() const
+const NumericResultNode &
+AverageAggregationResult::getAverage() const
{
_averageScratchPad = _sum;
if ( _count > 0 ) {
@@ -275,12 +310,14 @@ const NumericResultNode & AverageAggregationResult::getAverage() const
return *_averageScratchPad;
}
-void XorAggregationResult::onMerge(const AggregationResult & b)
+void
+XorAggregationResult::onMerge(const AggregationResult & b)
{
_xor.xorOp(static_cast<const XorAggregationResult &>(b)._xor);
}
-void XorAggregationResult::onAggregate(const ResultNode & result)
+void
+XorAggregationResult::onAggregate(const ResultNode & result)
{
if (result.isMultiValue()) {
for (size_t i(0), m(static_cast<const ResultNodeVector &>(result).size()); i < m; i++) {
@@ -291,21 +328,24 @@ void XorAggregationResult::onAggregate(const ResultNode & result)
}
}
-void XorAggregationResult::onReset()
+void
+XorAggregationResult::onReset()
{
_xor = 0;
}
static FieldBase _G_tagField("tag");
-Serializer & AggregationResult::onSerialize(Serializer & os) const
+Serializer &
+AggregationResult::onSerialize(Serializer & os) const
{
return (os << *_expressionTree).put(_G_tagField, _tag);
}
-Deserializer & AggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+AggregationResult::onDeserialize(Deserializer & is)
{
- _expressionTree.reset(new ExpressionTree());
+ _expressionTree = std::make_shared<ExpressionTree>();
return (is >> *_expressionTree).get(_G_tagField, _tag);
}
@@ -315,18 +355,21 @@ AggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "expression", _expressionTree);
}
-void AggregationResult::selectMembers(const vespalib::ObjectPredicate & predicate, vespalib::ObjectOperation & operation)
+void
+AggregationResult::selectMembers(const vespalib::ObjectPredicate & predicate, vespalib::ObjectOperation & operation)
{
_expressionTree->select(predicate,operation);
}
-Serializer & CountAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+CountAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return _count.serialize(os);
}
-Deserializer & CountAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+CountAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return _count.deserialize(is);
@@ -339,27 +382,27 @@ CountAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "count", _count);
}
-Serializer & SumAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+SumAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return os << _sum;
}
-Deserializer & SumAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+SumAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return is >> _sum;
}
-SumAggregationResult::SumAggregationResult()
- : AggregationResult(),
- _sum()
-{ }
-SumAggregationResult::SumAggregationResult(SingleResultNode::UP sum)
+SumAggregationResult::SumAggregationResult() = default;
+
+SumAggregationResult::SumAggregationResult(NumericResultNode::UP sum)
: AggregationResult(),
_sum(sum.release())
{ }
-SumAggregationResult::~SumAggregationResult() {}
+SumAggregationResult::~SumAggregationResult() = default;
void
SumAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
@@ -368,13 +411,15 @@ SumAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "sum", _sum);
}
-Serializer & MinAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+MinAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return os << _min;
}
-Deserializer & MinAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+MinAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return is >> _min;
@@ -387,13 +432,15 @@ MinAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "min", _min);
}
-Serializer & MaxAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+MaxAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return os << _max;
}
-Deserializer & MaxAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+MaxAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return is >> _max;
@@ -406,16 +453,19 @@ MaxAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "max", _max);
}
-static FieldBase _G_countField("count");
-static FieldBase _G_sumField("sum");
+namespace {
+ FieldBase _G_countField("count");
+}
-Serializer & AverageAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+AverageAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return os.put(_G_countField, _count) << _sum;
}
-Deserializer & AverageAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+AverageAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return is.get(_G_countField, _count) >> _sum;
@@ -429,13 +479,15 @@ AverageAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
visit(visitor, "sum", _sum);
}
-Serializer & XorAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+XorAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
return _xor.serialize(os);
}
-Deserializer & XorAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+XorAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
return _xor.deserialize(is);
@@ -451,7 +503,8 @@ XorAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
namespace {
// Calculates the sum of all buckets.
template <int BucketBits, typename HashT>
-int calculateRank(const Sketch<BucketBits, HashT> &sketch) {
+int
+calculateRank(const Sketch<BucketBits, HashT> &sketch) {
if (sketch.getClassId() == SparseSketch<BucketBits, HashT>::classId) {
return static_cast<const SparseSketch<BucketBits, HashT>&>(sketch)
.getSize();
@@ -465,13 +518,14 @@ int calculateRank(const Sketch<BucketBits, HashT> &sketch) {
}
} // namespace
-void ExpressionCountAggregationResult::onMerge(const AggregationResult &r) {
- const ExpressionCountAggregationResult &result =
- Identifiable::cast<const ExpressionCountAggregationResult &>(r);
+void
+ExpressionCountAggregationResult::onMerge(const AggregationResult &r) {
+ const auto & result = Identifiable::cast<const ExpressionCountAggregationResult &>(r);
_hll.merge(result._hll);
_rank.set(calculateRank(_hll.getSketch()));
}
-void ExpressionCountAggregationResult::onAggregate(const ResultNode &result) {
+void
+ExpressionCountAggregationResult::onAggregate(const ResultNode &result) {
size_t hash = result.hash();
const unsigned int seed = 42;
hash = XXH32(&hash, sizeof(hash), seed);
@@ -479,36 +533,38 @@ void ExpressionCountAggregationResult::onAggregate(const ResultNode &result) {
// almost the same ordering as the actual estimates.
_rank += _hll.aggregate(hash);
}
-void ExpressionCountAggregationResult::onReset() {
+void
+ExpressionCountAggregationResult::onReset() {
_hll = HyperLogLog<PRECISION>();
_rank.set(0);
}
-Serializer &ExpressionCountAggregationResult::onSerialize(
- Serializer &os) const {
+Serializer &
+ExpressionCountAggregationResult::onSerialize(Serializer &os) const {
AggregationResult::onSerialize(os);
_hll.serialize(os);
return os;
}
-Deserializer &ExpressionCountAggregationResult::onDeserialize(
- Deserializer &is) {
+Deserializer &
+ExpressionCountAggregationResult::onDeserialize(Deserializer &is) {
AggregationResult::onDeserialize(is);
_hll.deserialize(is);
_rank.set(calculateRank(_hll.getSketch()));
return is;
}
-ExpressionCountAggregationResult::ExpressionCountAggregationResult() : AggregationResult(), _hll() { }
-ExpressionCountAggregationResult::~ExpressionCountAggregationResult() {}
+ExpressionCountAggregationResult::ExpressionCountAggregationResult() = default;
+ExpressionCountAggregationResult::~ExpressionCountAggregationResult() = default;
StandardDeviationAggregationResult::StandardDeviationAggregationResult()
- : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad()
+ : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad()
{
_stdDevScratchPad.reset(new expression::FloatResultNode());
}
-StandardDeviationAggregationResult::~StandardDeviationAggregationResult() {}
+StandardDeviationAggregationResult::~StandardDeviationAggregationResult() = default;
-const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviation() const noexcept
+const NumericResultNode&
+StandardDeviationAggregationResult::getStandardDeviation() const noexcept
{
if (_count == 0) {
_stdDevScratchPad->set(Int64ResultNode(0));
@@ -520,15 +576,16 @@ const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviatio
return *_stdDevScratchPad;
}
-void StandardDeviationAggregationResult::onMerge(const AggregationResult &r) {
- const StandardDeviationAggregationResult &result =
- Identifiable::cast<const StandardDeviationAggregationResult &>(r);
+void
+StandardDeviationAggregationResult::onMerge(const AggregationResult &r) {
+ const auto & result = Identifiable::cast<const StandardDeviationAggregationResult &>(r);
_count += result._count;
_sum.add(result._sum);
_sumOfSquared.add(result._sumOfSquared);
}
-void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) {
+void
+StandardDeviationAggregationResult::onAggregate(const ResultNode &result) {
if (result.isMultiValue()) {
static_cast<const ResultNodeVector &>(result).flattenSum(_sum);
static_cast<const ResultNodeVector &>(result).flattenSumOfSquared(_sumOfSquared);
@@ -542,14 +599,16 @@ void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) {
}
}
-void StandardDeviationAggregationResult::onReset()
+void
+StandardDeviationAggregationResult::onReset()
{
_count = 0;
_sum.set(0.0);
_sumOfSquared.set(0.0);
}
-Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) const
+Serializer &
+StandardDeviationAggregationResult::onSerialize(Serializer & os) const
{
AggregationResult::onSerialize(os);
double sum = _sum.getFloat();
@@ -557,7 +616,8 @@ Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) co
return os << _count << sum << sumOfSquared;
}
-Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer & is)
+Deserializer &
+StandardDeviationAggregationResult::onDeserialize(Deserializer & is)
{
AggregationResult::onDeserialize(is);
double sum;
@@ -568,7 +628,8 @@ Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer &
return r;
}
-void StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
+void
+StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
{
AggregationResult::visitMembers(visitor);
visit(visitor, "count", _count);
diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h
index 765dcf23050..8587511497f 100644
--- a/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h
+++ b/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h
@@ -39,7 +39,7 @@ public:
AggregationResult & operator = (const AggregationResult &);
AggregationResult(AggregationResult &&) = default;
AggregationResult & operator = (AggregationResult &&) = default;
- ~AggregationResult();
+ ~AggregationResult() override;
class Configure : public vespalib::ObjectOperation, public vespalib::ObjectPredicate
{
private:
@@ -73,7 +73,7 @@ private:
void onPrepare(bool preserveAccurateTypes) override { (void) preserveAccurateTypes; }
bool onExecute() const override { return true; }
- void prepare() { if (getExpression() != NULL) { prepare(&getExpression()->getResult(), false); } }
+ void prepare() { if (getExpression() != nullptr) { prepare(&getExpression()->getResult(), false); } }
void prepare(const ResultNode * result, bool useForInit) { if (result) { onPrepare(*result, useForInit); } }
virtual void onPrepare(const ResultNode & result, bool useForInit) = 0;
virtual void onMerge(const AggregationResult & b) = 0;
diff --git a/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h
index 3d3395c63fc..96c6c34796a 100644
--- a/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h
+++ b/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h
@@ -12,7 +12,7 @@ public:
using NumericResultNode = expression::NumericResultNode;
DECLARE_AGGREGATIONRESULT(AverageAggregationResult);
AverageAggregationResult() : _sum(), _count(0) {}
- ~AverageAggregationResult();
+ ~AverageAggregationResult() override;
void visitMembers(vespalib::ObjectVisitor &visitor) const override;
const NumericResultNode & getAverage() const;
const NumericResultNode & getSum() const { return *_sum; }
diff --git a/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h
index 7309520c00d..aae77066817 100644
--- a/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h
+++ b/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h
@@ -2,24 +2,24 @@
#pragma once
#include "aggregationresult.h"
-#include <vespa/searchlib/expression/singleresultnode.h>
+#include <vespa/searchlib/expression/numericresultnode.h>
namespace search::aggregation {
class SumAggregationResult : public AggregationResult
{
public:
- using SingleResultNode = expression::SingleResultNode;
+ using NumericResultNode = expression::NumericResultNode;
DECLARE_AGGREGATIONRESULT(SumAggregationResult);
SumAggregationResult();
- SumAggregationResult(SingleResultNode::UP sum);
- ~SumAggregationResult();
+ SumAggregationResult(NumericResultNode::UP sum);
+ ~SumAggregationResult() override;
void visitMembers(vespalib::ObjectVisitor &visitor) const override;
- const SingleResultNode & getSum() const { return *_sum; }
+ const NumericResultNode & getSum() const { return *_sum; }
private:
const ResultNode & onGetRank() const override { return getSum(); }
void onPrepare(const ResultNode & result, bool useForInit) override;
- SingleResultNode::CP _sum;
+ NumericResultNode::CP _sum;
};
}
diff --git a/searchlib/src/vespa/searchlib/expression/numericresultnode.h b/searchlib/src/vespa/searchlib/expression/numericresultnode.h
index f14454e9403..e4c7d11b2d5 100644
--- a/searchlib/src/vespa/searchlib/expression/numericresultnode.h
+++ b/searchlib/src/vespa/searchlib/expression/numericresultnode.h
@@ -1,10 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
-#include <vespa/searchlib/expression/singleresultnode.h>
+#include "singleresultnode.h"
-namespace search {
-namespace expression {
+namespace search::expression {
class NumericResultNode : public SingleResultNode
{
@@ -19,5 +18,3 @@ public:
};
}
-}
-
diff --git a/searchlib/src/vespa/searchlib/expression/singleresultnode.h b/searchlib/src/vespa/searchlib/expression/singleresultnode.h
index 2417c15934b..663f6f8954f 100644
--- a/searchlib/src/vespa/searchlib/expression/singleresultnode.h
+++ b/searchlib/src/vespa/searchlib/expression/singleresultnode.h
@@ -3,8 +3,7 @@
#include "resultnode.h"
-namespace search {
-namespace expression {
+namespace search::expression {
class SingleResultNode : public ResultNode
{
@@ -26,5 +25,3 @@ public:
};
}
-}
-
diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
index 1f554cb9af7..6eff09b65ab 100644
--- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
+++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
@@ -7,8 +7,7 @@ using search::tensor::ITensorAttribute;
using vespalib::eval::Tensor;
using vespalib::tensor::MutableDenseTensorView;
-namespace search {
-namespace features {
+namespace search::features {
DenseTensorAttributeExecutor::
DenseTensorAttributeExecutor(const ITensorAttribute *attribute)
@@ -24,5 +23,4 @@ DenseTensorAttributeExecutor::execute(uint32_t docId)
outputs().set_object(0, _tensorView);
}
-} // namespace features
-} // namespace search
+}
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
index dffa3bb28b5..1dcd3e35580 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
@@ -12,6 +12,8 @@
#include <type_traits>
#include <vespa/log/log.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/objects/nbostream.h>
LOG_SETUP(".features.dotproduct");
@@ -340,11 +342,21 @@ ArrayParam<T>::ArrayParam(const Property & prop) {
parseVectors(prop, values, indexes);
}
+template <typename T>
+ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) {
+ vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values);
+}
+
+template <typename T>
+ArrayParam<T>::~ArrayParam() = default;
+
+
// Explicit instantiation since these are inspected by unit tests.
// FIXME this feels a bit dirty, consider breaking up ArrayParam to remove dependencies
// on templated vector parsing. This is why it's defined in this translation unit as it is.
-template struct ArrayParam<int64_t>;
+template ArrayParam<int64_t>::ArrayParam(const Property & prop);
template struct ArrayParam<double>;
+template struct ArrayParam<float>;
} // namespace dotproduct
@@ -609,43 +621,63 @@ fef::Anything::UP attemptParseArrayQueryVector(const IAttributeVector & attribut
} // anon ns
+const IAttributeVector *
+DotProductBlueprint::upgradeIfNecessary(const IAttributeVector * attribute, const IQueryEnvironment & env) const {
+ if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
+ attribute->hasEnum() &&
+ (attribute->isStringType() || attribute->isIntegerType()))
+ {
+ attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
+ }
+ return attribute;
+}
+
void
DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectStore & store) const
{
_attribute = env.getAttributeContext().getAttribute(getAttribute(env));
const IAttributeVector * attribute = _attribute;
- if (attribute != nullptr) {
- if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
- attribute->hasEnum() &&
- (attribute->isStringType() || attribute->isIntegerType()))
- {
- attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
+ if (attribute == nullptr) return;
+
+ attribute = upgradeIfNecessary(attribute, env);
+ fef::Anything::UP arguments;
+ if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
+ Property tensorBlob = env.getProperties().lookup(getBaseName(), _queryVector, "tensor");
+ if (attribute->isFloatingPointType() && tensorBlob.found() && !tensorBlob.get().empty()) {
+ const Property::Value & blob = tensorBlob.get();
+ vespalib::nbostream stream(blob.data(), blob.size());
+ if (attribute->getBasicType() == BasicType::FLOAT) {
+ arguments = std::make_unique<ArrayParam<float>>(stream);
+ } else {
+ arguments = std::make_unique<ArrayParam<double>>(stream);
+ }
+ } else {
+ Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
+ if (prop.found() && !prop.get().empty()) {
+ arguments = attemptParseArrayQueryVector(*attribute, prop);
+ }
}
+ } else if (attribute->getCollectionType() == attribute::CollectionType::WSET) {
Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
if (prop.found() && !prop.get().empty()) {
- fef::Anything::UP arguments;
- if (attribute->getCollectionType() == attribute::CollectionType::WSET) {
- if (attribute->isStringType() && attribute->hasEnum()) {
+ if (attribute->isStringType() && attribute->hasEnum()) {
+ dotproduct::wset::EnumVector vector(attribute);
+ WeightedSetParser::parse(prop.get(), vector);
+ } else if (attribute->isIntegerType()) {
+ if (attribute->hasEnum()) {
dotproduct::wset::EnumVector vector(attribute);
WeightedSetParser::parse(prop.get(), vector);
- } else if (attribute->isIntegerType()) {
- if (attribute->hasEnum()) {
- dotproduct::wset::EnumVector vector(attribute);
- WeightedSetParser::parse(prop.get(), vector);
- } else {
- dotproduct::wset::IntegerVector vector;
- WeightedSetParser::parse(prop.get(), vector);
- }
+ } else {
+ dotproduct::wset::IntegerVector vector;
+ WeightedSetParser::parse(prop.get(), vector);
}
- // TODO actually use the parsed output for wset operations!
- } else if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
- arguments = attemptParseArrayQueryVector(*attribute, prop);
- }
- if (arguments.get()) {
- store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
}
+ // TODO actually use the parsed output for wset operations!
}
}
+ if (arguments) {
+ store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
+ }
}
FeatureExecutor &
@@ -657,12 +689,7 @@ DotProductBlueprint::createExecutor(const IQueryEnvironment & env, vespalib::Sta
getAttribute(env).c_str());
return stash.create<SingleZeroValueExecutor>();
}
- if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
- attribute->hasEnum() &&
- (attribute->isStringType() || attribute->isIntegerType()))
- {
- attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
- }
+ attribute = upgradeIfNecessary(attribute, env);
const fef::Anything * argument = env.getObjectStore().get(getBaseName() + "." + _queryVector + "." + OBJECT);
if (argument != nullptr) {
return createFromObject(attribute, *argument, stash);
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
index b6107a1a271..089066cb5f6 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
@@ -10,6 +10,7 @@
#include <vespa/vespalib/stllike/hash_map.hpp>
namespace search::fef { class Property; }
+namespace vespalib { class nbostream; }
namespace search::features {
@@ -34,6 +35,9 @@ struct Converter<vespalib::string, const char *> {
template <typename T>
struct ArrayParam : public fef::Anything {
ArrayParam(const fef::Property & prop);
+ ArrayParam(vespalib::nbostream & stream);
+ ArrayParam(std::vector<T> v) : values(std::move(v)) {}
+ ~ArrayParam() override;
std::vector<T> values;
std::vector<uint32_t> indexes;
};
@@ -260,12 +264,14 @@ private:
*/
class DotProductBlueprint : public fef::Blueprint {
private:
+ using IAttributeVector = attribute::IAttributeVector;
vespalib::string _defaultAttribute;
vespalib::string _queryVector;
- mutable const attribute::IAttributeVector * _attribute;
+ mutable const IAttributeVector * _attribute;
vespalib::string getAttribute(const fef::IQueryEnvironment & env) const;
+ const IAttributeVector * upgradeIfNecessary(const IAttributeVector * attribute, const fef::IQueryEnvironment & env) const;
public:
DotProductBlueprint();
diff --git a/searchlib/src/vespa/searchlib/features/queryfeature.cpp b/searchlib/src/vespa/searchlib/features/queryfeature.cpp
index c5488581d29..eb7eb427283 100644
--- a/searchlib/src/vespa/searchlib/features/queryfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/queryfeature.cpp
@@ -3,9 +3,9 @@
#include "queryfeature.h"
#include "utils.h"
#include "valuefeature.h"
+#include "constant_tensor_executor.h"
#include <vespa/document/datatype/tensor_data_type.h>
-#include <vespa/searchlib/features/constant_tensor_executor.h>
#include <vespa/searchlib/fef/featureexecutor.h>
#include <vespa/searchlib/fef/indexproperties.h>
#include <vespa/searchlib/fef/properties.h>
@@ -25,8 +25,7 @@ using document::TensorDataType;
using vespalib::eval::ValueType;
using search::fef::FeatureType;
-namespace search {
-namespace features {
+namespace search::features {
namespace {
@@ -65,25 +64,21 @@ QueryBlueprint::QueryBlueprint() :
{
}
-QueryBlueprint::~QueryBlueprint()
-{
-}
+QueryBlueprint::~QueryBlueprint() = default;
void
-QueryBlueprint::visitDumpFeatures(const IIndexEnvironment &,
- IDumpFeatureVisitor &) const
+QueryBlueprint::visitDumpFeatures(const IIndexEnvironment &, IDumpFeatureVisitor &) const
{
}
Blueprint::UP
QueryBlueprint::createInstance() const
{
- return Blueprint::UP(new QueryBlueprint());
+ return std::make_unique<QueryBlueprint>();
}
bool
-QueryBlueprint::setup(const IIndexEnvironment &env,
- const ParameterList &params)
+QueryBlueprint::setup(const IIndexEnvironment &env, const ParameterList &params)
{
_key = params[0].getValue();
_key2 = "$";
@@ -107,19 +102,18 @@ QueryBlueprint::setup(const IIndexEnvironment &env,
FeatureType output_type = _valueType.is_tensor()
? FeatureType::object(_valueType)
: FeatureType::number();
- describeOutput("out", "The value looked up in query properties using the given key.",
- output_type);
+ describeOutput("out", "The value looked up in query properties using the given key.", output_type);
return true;
}
namespace {
FeatureExecutor &
-createTensorExecutor(const search::fef::IQueryEnvironment &env,
+createTensorExecutor(const IQueryEnvironment &env,
const vespalib::string &queryKey,
const ValueType &valueType, vespalib::Stash &stash)
{
- search::fef::Property prop = env.getProperties().lookup(queryKey);
+ Property prop = env.getProperties().lookup(queryKey);
if (prop.found() && !prop.get().empty()) {
const vespalib::string &value = prop.get();
vespalib::nbostream stream(value.data(), value.size());
@@ -156,5 +150,4 @@ QueryBlueprint::createExecutor(const IQueryEnvironment &env, vespalib::Stash &st
}
}
-} // namespace features
-} // namespace search
+}
diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp b/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp
index 9786593637e..288c0f5d1d0 100644
--- a/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp
@@ -16,6 +16,9 @@ FakeResult::FakeResult(const FakeResult &) = default;
FakeResult::~FakeResult() = default;
+FakeResult &
+FakeResult::operator=(const FakeResult &) = default;
+
std::ostream &operator << (std::ostream &out, const FakeResult &result) {
const std::vector<FakeResult::Document> &doc = result.inspect();
if (doc.size() == 0) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_result.h b/searchlib/src/vespa/searchlib/queryeval/fake_result.h
index ecb7dd377b9..ddf1fa61b63 100644
--- a/searchlib/src/vespa/searchlib/queryeval/fake_result.h
+++ b/searchlib/src/vespa/searchlib/queryeval/fake_result.h
@@ -48,6 +48,7 @@ public:
FakeResult();
FakeResult(const FakeResult &);
~FakeResult();
+ FakeResult &operator=(const FakeResult &);
FakeResult &doc(uint32_t docId) {
_documents.push_back(Document(docId));
diff --git a/security-utils/pom.xml b/security-utils/pom.xml
index 10dec598915..f7704762250 100644
--- a/security-utils/pom.xml
+++ b/security-utils/pom.xml
@@ -31,6 +31,16 @@
<artifactId>jackson-databind</artifactId>
<scope>compile</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpcore</artifactId>
+ <scope>compile</scope>
+ </dependency>
<!-- test scope -->
<dependency>
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java
new file mode 100644
index 00000000000..9fa51fc36cb
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java
@@ -0,0 +1,109 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.MixedMode;
+import com.yahoo.security.tls.TlsContext;
+import com.yahoo.security.tls.TransportSecurityUtils;
+import org.apache.http.HttpRequest;
+import org.apache.http.HttpRequestInterceptor;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.conn.HttpClientConnectionManager;
+import org.apache.http.conn.ssl.NoopHostnameVerifier;
+import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.apache.http.protocol.HttpContext;
+
+import javax.net.ssl.SSLParameters;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Http client builder for internal Vespa communications over http/https.
+ *
+ * Notes:
+ * - hostname verification is not enabled - CN/SAN verification is assumed to be handled by the underlying x509 trust manager.
+ * - custom connection managers must be configured through {@link #createBuilder(ConnectionManagerFactory)}. Do not call {@link HttpClientBuilder#setConnectionManager(HttpClientConnectionManager)}.
+ *
+ * @author bjorncs
+ */
+public class VespaHttpClientBuilder {
+
+ private static final Logger log = Logger.getLogger(VespaHttpClientBuilder.class.getName());
+
+ public interface ConnectionManagerFactory {
+ HttpClientConnectionManager create(SSLConnectionSocketFactory sslSocketFactory);
+ }
+
+ private VespaHttpClientBuilder() {}
+
+ public static HttpClientBuilder create() {
+ return createBuilder(null);
+ }
+
+ public static HttpClientBuilder create(ConnectionManagerFactory connectionManagerFactory) {
+ return createBuilder(connectionManagerFactory);
+ }
+
+ private static HttpClientBuilder createBuilder(ConnectionManagerFactory connectionManagerFactory) {
+ var builder = HttpClientBuilder.create();
+ addSslSocketFactory(builder, connectionManagerFactory);
+ addTlsAwareRequestInterceptor(builder);
+ return builder;
+ }
+
+ private static void addSslSocketFactory(HttpClientBuilder builder, ConnectionManagerFactory connectionManagerFactory) {
+ TransportSecurityUtils.createTlsContext()
+ .ifPresent(tlsContext -> {
+ log.log(Level.FINE, "Adding ssl socket factory to client");
+ SSLConnectionSocketFactory socketFactory = createSslSocketFactory(tlsContext);
+ if (connectionManagerFactory != null) {
+ builder.setConnectionManager(connectionManagerFactory.create(socketFactory));
+ } else {
+ builder.setSSLSocketFactory(socketFactory);
+ }
+ });
+ }
+
+ private static void addTlsAwareRequestInterceptor(HttpClientBuilder builder) {
+ if (TransportSecurityUtils.isTransportSecurityEnabled()
+ && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER) {
+ log.log(Level.FINE, "Adding request interceptor to client");
+ builder.addInterceptorFirst(new HttpToHttpsRewritingRequestInterceptor());
+ }
+ }
+
+ private static SSLConnectionSocketFactory createSslSocketFactory(TlsContext tlsContext) {
+ SSLParameters parameters = tlsContext.parameters();
+ return new SSLConnectionSocketFactory(tlsContext.context(), parameters.getProtocols(), parameters.getCipherSuites(), new NoopHostnameVerifier());
+ }
+
+ static class HttpToHttpsRewritingRequestInterceptor implements HttpRequestInterceptor {
+ @Override
+ public void process(HttpRequest request, HttpContext context) {
+ if (request instanceof HttpRequestBase) {
+ HttpRequestBase httpUriRequest = (HttpRequestBase) request;
+ httpUriRequest.setURI(rewriteUri(httpUriRequest.getURI()));
+ } else {
+ log.log(Level.FINE, () -> "Not a HttpRequestBase - skipping URI rewriting: " + request.getClass().getName());
+ }
+ }
+
+ private static URI rewriteUri(URI originalUri) {
+ if (!originalUri.getScheme().equals("http")) {
+ return originalUri;
+ }
+ int port = originalUri.getPort();
+ int rewrittenPort = port != -1 ? port : 80;
+ try {
+ URI rewrittenUri = new URIBuilder(originalUri).setScheme("https").setPort(rewrittenPort).build();
+ log.log(Level.FINE, () -> String.format("Uri rewritten from '%s' to '%s'", originalUri, rewrittenUri));
+ return rewrittenUri;
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+}
diff --git a/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java b/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java
new file mode 100644
index 00000000000..10b8458359c
--- /dev/null
+++ b/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java
@@ -0,0 +1,39 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.https.VespaHttpClientBuilder.HttpToHttpsRewritingRequestInterceptor;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.protocol.BasicHttpContext;
+import org.junit.Test;
+
+import java.net.URI;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
+/**
+ * @author bjorncs
+ */
+public class VespaHttpClientBuilderTest {
+
+ @Test
+ public void request_interceptor_modifies_scheme_of_requests() {
+ verifyProcessedUriMatchesExpectedOutput("http://dummyhostname:8080/a/path/to/resource?query=value",
+ "https://dummyhostname:8080/a/path/to/resource?query=value");
+ }
+
+ @Test
+ public void request_interceptor_add_handles_implicit_http_port() {
+ verifyProcessedUriMatchesExpectedOutput("http://dummyhostname/a/path/to/resource?query=value",
+ "https://dummyhostname:80/a/path/to/resource?query=value");
+ }
+
+ private static void verifyProcessedUriMatchesExpectedOutput(String inputUri, String expectedOutputUri) {
+ var interceptor = new HttpToHttpsRewritingRequestInterceptor();
+ HttpGet request = new HttpGet(inputUri);
+ interceptor.process(request, new BasicHttpContext());
+ URI modifiedUri = request.getURI();
+ URI expectedUri = URI.create(expectedOutputUri);
+ assertThat(modifiedUri).isEqualTo(expectedUri);
+ }
+
+} \ No newline at end of file
diff --git a/vespa-documentgen-plugin/etc/complex/music3.sd b/vespa-documentgen-plugin/etc/complex/music3.sd
index 65f37029d04..45ce11fd581 100644
--- a/vespa-documentgen-plugin/etc/complex/music3.sd
+++ b/vespa-documentgen-plugin/etc/complex/music3.sd
@@ -4,5 +4,8 @@ search music3 {
field mu3 type string {
}
+ field pos type position {
+
+ }
}
}
diff --git a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java
index 7e73d6b5915..bc34a4ac3df 100644
--- a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java
+++ b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java
@@ -3,11 +3,14 @@ package com.yahoo.vespa;
import com.yahoo.collections.Pair;
import com.yahoo.document.ArrayDataType;
+import com.yahoo.document.CollectionDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.Field;
import com.yahoo.document.MapDataType;
+import com.yahoo.document.PositionDataType;
import com.yahoo.document.ReferenceDataType;
import com.yahoo.document.StructDataType;
+import com.yahoo.document.StructuredDataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.WeightedSetDataType;
import com.yahoo.document.annotation.AnnotationReferenceDataType;
@@ -18,7 +21,6 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
import org.apache.maven.plugin.AbstractMojo;
-import org.apache.maven.plugin.MojoFailureException;
import org.apache.maven.plugins.annotations.Component;
import org.apache.maven.plugins.annotations.LifecyclePhase;
import org.apache.maven.plugins.annotations.Mojo;
@@ -31,6 +33,7 @@ import java.io.FilenameFilter;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
@@ -468,6 +471,9 @@ public class DocumentGenMojo extends AbstractMojo {
exportHashCode(allUniqueFields, out, 1, "(getDataType() != null ? getDataType().hashCode() : 0) + getId().hashCode()");
exportEquals(className, allUniqueFields, out, 1);
Set<DataType> exportedStructs = exportStructTypes(docType.getTypes(), out, 1, null);
+ if (hasAnyPositionField(allUniqueFields)) {
+ exportedStructs = exportStructTypes(Arrays.asList(PositionDataType.INSTANCE), out, 1, exportedStructs);
+ }
docTypes.put(docType.getName(), packageName+"."+className);
for (DataType exportedStruct : exportedStructs) {
structTypes.put(exportedStruct.getName(), packageName+"."+className+"."+className(exportedStruct.getName()));
@@ -475,6 +481,25 @@ public class DocumentGenMojo extends AbstractMojo {
out.write("}\n");
}
+ private static boolean hasAnyPostionDataType(DataType dt) {
+ if (dt instanceof CollectionDataType) {
+ return hasAnyPostionDataType(((CollectionDataType)dt).getNestedType());
+ } else if (dt instanceof StructuredDataType) {
+ return hasAnyPositionField(((StructuredDataType)dt).getFields());
+ } else {
+ return PositionDataType.INSTANCE.equals(dt);
+ }
+ }
+
+ private static boolean hasAnyPositionField(Collection<Field> fields) {
+ for (Field f : fields) {
+ if (hasAnyPostionDataType(f.getDataType())) {
+ return true;
+ }
+ }
+ return true;
+ }
+
private Collection<Field> getAllUniqueFields(Boolean multipleInheritance, Collection<Field> allFields) {
if (multipleInheritance) {
Map<String, Field> seen = new HashMap<>();
@@ -732,7 +757,8 @@ public class DocumentGenMojo extends AbstractMojo {
ind(ind)+" * Input struct type: "+structType.getName()+"\n" +
ind(ind)+" * Date: "+new Date()+"\n" +
ind(ind)+" */\n" +
- ind(ind)+"@com.yahoo.document.Generated public static class "+structClassName+" extends com.yahoo.document.datatypes.Struct {\n\n" +
+ ind(ind)+"@com.yahoo.document.Generated\n" +
+ ind(ind) + "public static class "+structClassName+" extends com.yahoo.document.datatypes.Struct {\n\n" +
ind(ind+1)+"/** The type of this.*/\n" +
ind(ind+1)+"public static final com.yahoo.document.StructDataType type = getStructType();\n\n");
out.write(ind(ind+1)+"public "+structClassName+"() {\n" +
diff --git a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java
index b21f38c586a..c195e116bf0 100644
--- a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java
+++ b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java
@@ -5,8 +5,6 @@ import com.yahoo.document.DataType;
import com.yahoo.document.StructDataType;
import com.yahoo.document.WeightedSetDataType;
import com.yahoo.searchdefinition.Search;
-import org.apache.maven.plugin.MojoExecutionException;
-import org.apache.maven.plugin.MojoFailureException;
import org.junit.Test;
import java.io.File;
@@ -19,7 +17,7 @@ import static org.junit.Assert.fail;
public class DocumentGenTest {
@Test
- public void testMusic() throws MojoExecutionException, MojoFailureException {
+ public void testMusic() {
DocumentGenMojo mojo = new DocumentGenMojo();
mojo.execute(new File("etc/music/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document");
Map<String, Search> searches = mojo.getSearches();
@@ -28,19 +26,21 @@ public class DocumentGenTest {
}
@Test
- public void testComplex() throws MojoFailureException {
+ public void testComplex() {
DocumentGenMojo mojo = new DocumentGenMojo();
mojo.execute(new File("etc/complex/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document");
Map<String, Search> searches = mojo.getSearches();
assertEquals(searches.get("video").getDocument("video").getField("weight").getDataType(), DataType.FLOAT);
assertEquals(searches.get("book").getDocument("book").getField("sw1").getDataType(), DataType.FLOAT);
+ assertTrue(searches.get("music3").getDocument("music3").getField("pos").getDataType() instanceof StructDataType);
+ assertEquals(searches.get("music3").getDocument("music3").getField("pos").getDataType().getName(), "position");
assertTrue(searches.get("book").getDocument("book").getField("mystruct").getDataType() instanceof StructDataType);
assertTrue(searches.get("book").getDocument("book").getField("mywsfloat").getDataType() instanceof WeightedSetDataType);
assertTrue(((WeightedSetDataType)(searches.get("book").getDocument("book").getField("mywsfloat").getDataType())).getNestedType() == DataType.FLOAT);
}
@Test
- public void testLocalApp() throws MojoFailureException {
+ public void testLocalApp() {
DocumentGenMojo mojo = new DocumentGenMojo();
mojo.execute(new File("etc/localapp/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document");
Map<String, Search> searches = mojo.getSearches();
@@ -51,7 +51,7 @@ public class DocumentGenTest {
}
@Test
- public void testEmptyPkgNameForbidden() throws MojoFailureException {
+ public void testEmptyPkgNameForbidden() {
DocumentGenMojo mojo = new DocumentGenMojo();
try {
mojo.execute(new File("etc/localapp/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "");
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 239efa0f89c..43388e4e18d 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -947,7 +947,7 @@
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public long denseSubspaceSize()",
- "public static com.yahoo.tensor.TensorType createPartialType(java.util.List)"
+ "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)"
],
"fields": []
},
@@ -1162,11 +1162,10 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.tensor.TensorType$ValueType)",
+ "public void <init>(com.yahoo.tensor.TensorType$Value)",
"public varargs void <init>(com.yahoo.tensor.TensorType[])",
- "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])",
"public void <init>(java.lang.Iterable)",
- "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)",
+ "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)",
"public int rank()",
"public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)",
"public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)",
@@ -1270,7 +1269,7 @@
],
"fields": []
},
- "com.yahoo.tensor.TensorType$ValueType": {
+ "com.yahoo.tensor.TensorType$Value": {
"superClass": "java.lang.Enum",
"interfaces": [],
"attributes": [
@@ -1279,12 +1278,14 @@
"enum"
],
"methods": [
- "public static com.yahoo.tensor.TensorType$ValueType[] values()",
- "public static com.yahoo.tensor.TensorType$ValueType valueOf(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType$Value[] values()",
+ "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)",
+ "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)",
+ "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)"
],
"fields": [
- "public static final enum com.yahoo.tensor.TensorType$ValueType DOUBLE",
- "public static final enum com.yahoo.tensor.TensorType$ValueType FLOAT"
+ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
+ "public static final enum com.yahoo.tensor.TensorType$Value FLOAT"
]
},
"com.yahoo.tensor.TensorType": {
@@ -1294,9 +1295,8 @@
"public"
],
"methods": [
- "public final com.yahoo.tensor.TensorType$ValueType valueType()",
- "public final com.yahoo.tensor.TensorType valueType(com.yahoo.tensor.TensorType$ValueType)",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
+ "public com.yahoo.tensor.TensorType$Value valueType()",
"public int rank()",
"public java.util.List dimensions()",
"public java.util.Set dimensionNames()",
@@ -1325,7 +1325,7 @@
"methods": [
"public void <init>()",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
- "public static java.util.List dimensionsFromSpec(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 08878edeb83..c06cb2a0986 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -319,7 +319,7 @@ public class MixedTensor implements Tensor {
}
public TensorType createBoundType() {
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType());
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (!dimension.isIndexed()) {
@@ -355,8 +355,8 @@ public class MixedTensor implements Tensor {
this.type = type;
this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList());
- this.sparseType = createPartialType(mappedDimensions);
- this.denseType = createPartialType(indexedDimensions);
+ this.sparseType = createPartialType(type.valueType(), mappedDimensions);
+ this.denseType = createPartialType(type.valueType(), indexedDimensions);
}
public long indexOf(TensorAddress address) {
@@ -476,8 +476,8 @@ public class MixedTensor implements Tensor {
}
- public static TensorType createPartialType(List<TensorType.Dimension> dimensions) {
- TensorType.Builder builder = new TensorType.Builder();
+ public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
+ TensorType.Builder builder = new TensorType.Builder(valueType);
for (TensorType.Dimension dimension : dimensions) {
builder.set(dimension);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index fa32d385004..45a9992c9ad 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -11,14 +11,14 @@ class TensorParser {
static Tensor tensorFrom(String tensorString, Optional<TensorType> type) {
tensorString = tensorString.trim();
try {
- if (tensorString.startsWith("tensor(")) {
+ if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
- "passed type " + type);
+ "passed type " + type.get());
return tensorFromValueString(valueString, typeFromString);
}
else if (tensorString.startsWith("{")) {
@@ -48,7 +48,7 @@ class TensorParser {
addressBody = addressBody.substring(1); // remove key start
if (addressBody.isEmpty()) return TensorType.empty; // Empty key
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE);
for (String elementString : addressBody.split(",")) {
String[] pair = elementString.split(":");
if (pair.length != 2)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 036f5e3ee5d..df78f3dfc3a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
@@ -24,25 +25,40 @@ import java.util.stream.Collectors;
*/
public class TensorType {
- public enum ValueType { DOUBLE, FLOAT};
+ /** The permissible cell value types. Default is double. */
+ public enum Value {
- /** The empty tensor type - which is the same as a double */
- public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList());
+ // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
+ DOUBLE, FLOAT;
- private ValueType valueType;
+ public static Value largestOf(List<Value> values) {
+ if (values.isEmpty()) return Value.DOUBLE; // Default
+ Value largest = null;
+ for (Value value : values) {
+ if (largest == null)
+ largest = value;
+ else
+ largest = largestOf(largest, value);
+ }
+ return largest;
+ }
- public final ValueType valueType() { return valueType; }
+ public static Value largestOf(Value value1, Value value2) {
+ if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ return FLOAT;
+ }
- //TODO Remove once value type is wired in were it should.
- public final TensorType valueType(ValueType valueType) {
- this.valueType = valueType;
- return this;
- }
+ };
+
+ /** The empty tensor type - which is the same as a double */
+ public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
+
+ private final Value valueType;
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(ValueType valueType, Collection<Dimension> dimensions) {
+ private TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
@@ -64,6 +80,9 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the numeric type of the cell values of this */
+ public Value valueType() { return valueType; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }
@@ -149,10 +168,14 @@ public class TensorType {
}
@Override
- public boolean equals(Object other) {
- if (this == other) return true;
- if (other == null || getClass() != other.getClass()) return false;
- return dimensions.equals(((TensorType)other).dimensions);
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ TensorType other = (TensorType)o;
+ if ( this.valueType != other.valueType) return false;
+ if ( ! this.dimensions.equals(other.dimensions)) return false;
+ return true;
}
/** Returns whether the given type has the same dimension names as this */
@@ -173,7 +196,7 @@ public class TensorType {
if (this.equals(other)) return Optional.of(this); // shortcut
if (this.dimensions.size() != other.dimensions.size()) return Optional.empty();
- Builder b = new Builder();
+ Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType));
for (int i = 0; i < dimensions.size(); i++) {
Dimension thisDim = this.dimensions().get(i);
Dimension otherDim = other.dimensions().get(i);
@@ -386,14 +409,14 @@ public class TensorType {
private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
- private final ValueType valueType;
+ private final Value valueType;
- /** Creates an empty builder with cells of type double*/
+ /** Creates an empty builder with cells of type double */
public Builder() {
- this(ValueType.DOUBLE);
+ this(Value.DOUBLE);
}
- public Builder(ValueType valueType) {
+ public Builder(Value valueType) {
this.valueType = valueType;
}
@@ -403,23 +426,22 @@ public class TensorType {
* If the same dimension is indexed with different size restrictions the largest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
+ *
+ * The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
- this(ValueType.DOUBLE, types);
- }
- public Builder(ValueType valueType, TensorType ... types) {
- this.valueType = valueType;
+ this.valueType = TensorType.Value.largestOf(Arrays.stream(types).map(type -> type.valueType()).collect(Collectors.toList()));
for (TensorType type : types)
addDimensionsOf(type);
}
- /**
- * Creates a builder from the given dimensions.
- */
+ /** Creates a builder from the given dimensions, having double as the value type */
public Builder(Iterable<Dimension> dimensions) {
- this(ValueType.DOUBLE, dimensions);
+ this(Value.DOUBLE, dimensions);
}
- public Builder(ValueType valueType, Iterable<Dimension> dimensions) {
+
+ /** Creates a builder from the given value type and dimensions */
+ public Builder(Value valueType, Iterable<Dimension> dimensions) {
this.valueType = valueType;
for (TensorType.Dimension dimension : dimensions) {
dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 32ad6171e57..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -2,8 +2,10 @@
package com.yahoo.tensor;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -11,26 +13,36 @@ import java.util.regex.Pattern;
* Class for parsing a tensor type spec.
*
* @author geirst
+ * @author bratseth
*/
public class TensorTypeParser {
- private final static String START_STRING = "tensor(";
+ private final static String START_STRING = "tensor";
private final static String END_STRING = ")";
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
- return new TensorType.Builder(dimensionsFromSpec(specString)).build();
- }
+ if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
+ throw formatException(specString);
+ String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- public static List<TensorType.Dimension> dimensionsFromSpec(String specString) {
- if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
- throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ String dimensionsSpec;
+ TensorType.Value valueType;
+ if (specBody.startsWith("(")) {
+ valueType = TensorType.Value.DOUBLE; // no value type spec: Use default
+ dimensionsSpec = specBody.substring(1);
+ }
+ else {
+ int parenthesisIndex = specBody.indexOf("(");
+ if (parenthesisIndex < 0)
+ throw formatException(specString);
+ valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString);
+ dimensionsSpec = specBody.substring(parenthesisIndex + 1);
}
- String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build();
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
@@ -38,10 +50,30 @@ public class TensorTypeParser {
boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
tryParseMappedDimension(trimmedElement, dimensions);
if ( ! success)
- throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
+ throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ }
+ return new TensorType.Builder(valueType, dimensions).build();
+ }
+
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
+ private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
+ if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
+ throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
+
+ try {
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ }
+ catch (IllegalArgumentException e) {
+ throw formatException(fullSpecString, e.getMessage());
}
- return dimensions;
}
private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
@@ -69,5 +101,21 @@ public class TensorTypeParser {
return false;
}
+
+ private static IllegalArgumentException formatException(String spec) {
+ return formatException(spec, Optional.empty());
+ }
+
+ private static IllegalArgumentException formatException(String spec, String errorDetail) {
+ return formatException(spec, Optional.of(errorDetail));
+ }
+
+ private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) {
+ throw new IllegalArgumentException("A tensor type spec must be on the form " +
+ "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " +
+ errorDetail.map(s -> s + ". ").orElse("") +
+ "Examples: tensor(x[]), tensor<float>(name{}, x[10])");
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 91ab4f9d046..a48ac19fbff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -73,8 +73,8 @@ public class Concat extends PrimitiveTensorFunction {
MutableLong concatSize = new MutableLong(0);
a.sizeOfDimension(dimension).ifPresent(concatSize::add);
b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
+ builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
+ */
}
return builder.build();
}
@@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
return tensor.multiply(unitTensor);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 62ee471fcf4..062e0d92e80 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
- /**
- * Returns common dimension of a and b as a new tensor type
- */
+ /** Returns common dimension of a and b as a new tensor type */
private static TensorType commonDimensions(Tensor a, Tensor b) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
+ TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(),
+ bType.valueType()));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 54d7710c9dc..017dc3920e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder b = new TensorType.Builder();
+ TensorType.Builder b = new TensorType.Builder(inputType.valueType());
+ if (reduceDimensions.isEmpty()) return b.build(); // means reduce all
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
@@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static TensorType type(TensorType argumentType, List<String> dimensions) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(argumentType.valueType());
+ if (dimensions.isEmpty()) return builder.build(); // means reduce all
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index b268e33b418..db950e6c8b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction {
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(),
+ b.type().valueType()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index e18af235d59..5694684956e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction {
}
private TensorType type(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
for (TensorType.Dimension dimension : type.dimensions())
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 500c436516f..ecd4f7d1965 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -43,7 +43,7 @@ public class DenseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private void encodeValueType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
+ private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
switch (valueType) {
case DOUBLE:
if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) {
@@ -100,7 +100,7 @@ public class DenseBinaryFormat implements BinaryFormat {
sizes = sizesFromType(serializedType);
}
else {
- type = decodeType(buffer, TensorType.ValueType.DOUBLE);
+ type = decodeType(buffer, TensorType.Value.DOUBLE);
sizes = sizesFromType(type);
}
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
@@ -108,16 +108,16 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
- TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE;
- if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
+ private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+ TensorType.Value serializedValueType = TensorType.Value.DOUBLE;
+ if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
int type = buffer.getInt1_4Bytes();
switch (type) {
case DOUBLE_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.DOUBLE;
+ serializedValueType = TensorType.Value.DOUBLE;
break;
case FLOAT_VALUE_TYPE:
- serializedValueType = TensorType.ValueType.FLOAT;
+ serializedValueType = TensorType.Value.FLOAT;
break;
default:
throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal.");
@@ -141,7 +141,7 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
switch (valueType) {
case DOUBLE:
decodeCellsAsDouble(sizes, buffer, builder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index acaeb3ef5ba..284dfea2141 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat {
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
- " cannot be assigned to type " + type);
+ " cannot be assigned to type " + type);
}
else {
type = decodeType(buffer);
@@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat {
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
- TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
+ TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
long denseSubspaceSize = builder.denseSubspaceSize();
int numBlocks = 1;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index f2c5d4e2bd8..9b298f1dffb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -23,7 +23,9 @@ public class TypedBinaryFormat {
private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
private static final int DENSE_BINARY_FORMAT_TYPE = 2;
private static final int MIXED_BINARY_FORMAT_TYPE = 3;
- private static final int TYPED_DENSE_BINARY_FORMAT_TYPE = 4;
+ private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5;
+ private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
+ private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
@@ -38,7 +40,7 @@ public class TypedBinaryFormat {
new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor);
break;
default:
- buffer.putInt1_4Bytes(TYPED_DENSE_BINARY_FORMAT_TYPE);
+ buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor);
break;
}
@@ -67,7 +69,7 @@ public class TypedBinaryFormat {
case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer);
case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer);
- case TYPED_DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer);
+ case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer);
default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
index 9602bdb8d94..f6fed9d33ed 100644
--- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
@@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase {
all1234(parser);
}
- /**
- * Tests various legal inputs and print the output
- */
- @Test
- public void testPrint() {
- String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02";
- parser = new BoundingBoxParser(here);
- System.out.println(here+" -> "+parser);
- }
-
@Test
public void testGeoPlanetExample() {
/* example XML:
diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
index e8ceab44c78..7cf4bddaa01 100644
--- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
@@ -57,7 +57,6 @@ public class BinaryFormatTestCase {
@Test
public void testZigZagConversion() {
- System.out.println("test zigzag conversion");
assertThat(encode_zigzag(0), is((long)0));
assertThat(decode_zigzag(encode_zigzag(0)), is(0L));
@@ -88,7 +87,6 @@ public class BinaryFormatTestCase {
@Test
public void testDoubleConversion() {
- System.out.println("test double conversion");
assertThat(encode_double(0.0), is(0L));
assertThat(decode_double(encode_double(0.0)), is(0.0));
@@ -116,7 +114,6 @@ public class BinaryFormatTestCase {
@Test
public void testTypeAndMetaMangling() {
- System.out.println("test type and meta mangling");
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int meta = 0; meta < META_LIMIT; ++meta) {
byte mangled = encode_type_and_meta(type, meta);
@@ -126,10 +123,8 @@ public class BinaryFormatTestCase {
}
}
- // was testCmprUlong
@Test
- public void testCmprLong() {
- System.out.println("test compressed long");
+ public void testCompressedLong() {
{
long value = 0;
byte[] wanted = { 0 };
@@ -217,11 +212,8 @@ public class BinaryFormatTestCase {
// testWriteBytes -> buffered IO test
// testReadByte -> buffered IO test
// testReadBytes -> buffered IO test
-
@Test
- public void testTypeAndSize() {
- System.out.println("test type and size conversion");
-
+ public void testTypeAndSizeConversion() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (long size = 0; size < 500; ++size) {
BufferedOutput expect = new BufferedOutput();
@@ -271,8 +263,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testTypeAndBytes() {
- System.out.println("test encoding and decoding of type and bytes");
+ public void testEncodingAndDecodingOfTypeAndBytes() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int n = 0; n < MAX_NUM_SIZE; ++n) {
for (int pre = 0; (pre == 0) || (pre < n); ++pre) {
@@ -307,9 +298,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testEmpty() {
- System.out.println("test encoding empty slime");
-
+ public void testEncodingEmptySlime() {
Slime slime = new Slime();
BufferedOutput expect = new BufferedOutput();
expect.put((byte)0); // num symbols
@@ -321,8 +310,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testBasic() {
- System.out.println("test encoding slime holding a single basic value");
+ public void testEncodingSlimeHoldingASingleBasicValue() {
{
Slime slime = new Slime();
slime.setBool(false);
@@ -427,8 +415,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testArray() {
- System.out.println("test encoding slime holding an array of various basic values");
+ public void testEncodingSlimeArray() {
Slime slime = new Slime();
Cursor c = slime.setArray();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -452,8 +439,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testObject() {
- System.out.println("test encoding slime holding an object of various basic values");
+ public void testEncodingSlimeObject() {
Slime slime = new Slime();
Cursor c = slime.setObject();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -478,8 +464,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testNesting() {
- System.out.println("test encoding slime holding a more complex structure");
+ public void testEncodingComplexSlimeStructure() {
Slime slime = new Slime();
Cursor c1 = slime.setObject();
c1.setLong("bar", 10);
@@ -503,8 +488,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testSymbolReuse() {
- System.out.println("test encoding slime reusing symbols");
+ public void testEncodingSlimeReusingSymbols() {
Slime slime = new Slime();
Cursor c1 = slime.setArray();
{
@@ -533,8 +517,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testOptionalDecodeOrder() {
- System.out.println("test decoding slime with different symbol order");
+ public void testDecodingSlimeWithDifferentSymbolOrder() {
byte[] data = {
5, // num symbols
1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table
@@ -564,4 +547,5 @@ public class BinaryFormatTestCase {
assertThat(c.field("f").asData(), is(expd));
assertThat(c.entry(5).valid(), is(false)); // not ARRAY
}
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index f7a0a3cdb7d..d3bb702175a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -58,10 +58,13 @@ public class TensorTypeTestCase {
@Test
public void requireThatIllegalSyntaxInSpecThrowsException() {
- assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'");
- assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'");
- assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'");
- assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'");
+ assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'.");
+ assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format");
+ assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format");
+ assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format");
+ assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>");
+ assertIllegalTensorType("tensor<>(x{})", "Value type must be");
+ assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be");
}
@Test
@@ -88,6 +91,13 @@ public class TensorTypeTestCase {
assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])");
}
+ @Test
+ public void testValueType() {
+ assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
+ assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
+ assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ }
+
private static void assertTensorType(String typeSpec) {
assertTensorType(typeSpec, typeSpec);
}
@@ -121,4 +131,8 @@ public class TensorTypeTestCase {
assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
}
+ private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) {
+ assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index e8b17812f32..5d1bc7b0c3f 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -55,7 +55,7 @@ public class DenseBinaryFormatTestCase {
@Test
public void requireThatFloatSerializationFormatDoNotChange() {
- byte[] encodedTensor = new byte[]{4, // binary format type
+ byte[] encodedTensor = new byte[]{6, // binary format type
1, // float type
2, // dimension count
2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
@@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase {
64, 0, 0, 0, // value 1
64, 64, 0, 0, // value 2
};
- Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
- tensor.type().valueType(TensorType.ValueType.FLOAT);
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
- assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
- assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
}
private void assertSerialization(String tensorString) {
- assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString));
- }
- private void assertSerialization(TensorType.ValueType valueType, String tensorString) {
- assertSerialization(valueType, Tensor.from(tensorString));
+ assertSerialization(Tensor.from(tensorString));
}
- private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) {
- tensor.type().valueType(valueType);
+ private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
diff --git a/vespalog/src/vespa/log/log.cpp b/vespalog/src/vespa/log/log.cpp
index 8e3ed9a18ba..a43c6dd0416 100644
--- a/vespalog/src/vespa/log/log.cpp
+++ b/vespalog/src/vespa/log/log.cpp
@@ -379,7 +379,7 @@ Logger::doEventProgress(const char *name, double value, double total)
void
Logger::doEventCount(const char *name, uint64_t value)
{
- doLog(event, "", 0, "count/1 name=\"%s\" value=%lu", name, value);
+ doLog(event, "", 0, "count/1 name=\"%s\" value=%" PRIu64, name, value);
}
void
diff --git a/vespalog/src/vespa/log/log_message.cpp b/vespalog/src/vespa/log/log_message.cpp
index 77f9b619e9f..8ce7df93a12 100644
--- a/vespalog/src/vespa/log/log_message.cpp
+++ b/vespalog/src/vespa/log/log_message.cpp
@@ -31,18 +31,29 @@ find_tab(std::string_view log_line, const char *tab_name, std::string_view::size
}
int64_t
-parse_time_field(std::string time_field)
+parse_time_subfield(std::string time_subfield, const std::string &time_field)
{
- std::istringstream time_stream(time_field);
- time_stream.imbue(clocale);
- double logtime = 0;
- time_stream >> logtime;
- if (!time_stream.eof()) {
+ std::istringstream subfield_stream(time_subfield);
+ subfield_stream.imbue(clocale);
+ int64_t result = 0;
+ subfield_stream >> result;
+ if (!subfield_stream.eof()) {
std::ostringstream os;
os << "Bad time field: " << time_field;
throw BadLogLineException(os.str());
}
- return logtime * 1000000000;
+ return result;
+}
+
+int64_t
+parse_time_field(std::string time_field)
+{
+ auto dotPos = time_field.find('.');
+ int64_t log_time = parse_time_subfield(time_field.substr(0, dotPos), time_field) * 1000000000;
+ if (dotPos != std::string::npos) {
+ log_time += parse_time_subfield((time_field.substr(dotPos + 1) + "000000000").substr(0, 9), time_field);
+ }
+ return log_time;
}
struct PidFieldParser