summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--athenz-identity-provider-service/pom.xml8
-rw-r--r--athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java10
-rw-r--r--athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java28
-rw-r--r--athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java18
-rw-r--r--athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java1
-rw-r--r--athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java9
-rw-r--r--athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java7
-rw-r--r--bundle-plugin-test/pom.xml8
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala6
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala2
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala2
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala2
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala5
-rw-r--r--bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala7
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java9
-rw-r--r--config-model-fat/pom.xml9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java674
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java636
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java677
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/Host.java10
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java22
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java26
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java5
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java5
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java5
-rw-r--r--config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java2
-rw-r--r--config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java4
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java24
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java79
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java (renamed from configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java)30
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java6
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java2
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java5
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java2
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java5
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java1
-rw-r--r--configserver/src/main/resources/configserver-app/services.xml6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java17
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java29
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java (renamed from configserver/src/test/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceCheckerTest.java)47
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java17
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java10
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java15
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java23
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java8
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java20
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java18
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java14
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java32
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java6
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java25
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java18
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java12
-rw-r--r--container-core/src/main/java/com/yahoo/container/handler/VipStatus.java18
-rw-r--r--container-core/src/main/resources/configdefinitions/vip-status.def7
-rw-r--r--container-dependency-versions/pom.xml2
-rw-r--r--container-dev/pom.xml4
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java94
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java6
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java57
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java5
-rw-r--r--container-jersey2/pom.xml5
-rw-r--r--container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java73
-rw-r--r--container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java25
-rw-r--r--container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java118
-rw-r--r--container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java103
-rw-r--r--container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala40
-rw-r--r--container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala16
-rw-r--r--container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala109
-rw-r--r--container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala74
-rw-r--r--container/pom.xml6
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java8
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java8
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java47
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java218
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java56
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java3
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java61
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java55
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java55
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java20
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java26
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java14
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java4
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java32
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java13
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java5
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java34
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java19
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java64
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java41
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json6
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json4
-rw-r--r--docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java91
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/Array.java3
-rw-r--r--fat-model-dependencies/pom.xml7
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java17
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java68
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java6
-rw-r--r--jdisc_http_service/pom.xml3
-rw-r--r--jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java2
-rw-r--r--jdisc_jetty/pom.xml4
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java4
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java104
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java2
-rw-r--r--node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java14
-rw-r--r--node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java12
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java4
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java3
-rw-r--r--orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java2
-rw-r--r--parent/pom.xml10
-rw-r--r--searchcore/src/tests/proton/matching/query_test.cpp53
-rw-r--r--searchcore/src/tests/proton/matching/querynodes_test.cpp57
-rw-r--r--searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp17
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java)101
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java242
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java47
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java)9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java)10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java)154
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java216
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java52
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java)53
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java)21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java)118
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java)29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java)24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java)33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java234
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java)3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java)2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java326
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java64
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java411
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java97
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java255
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java145
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java46
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java5
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java)6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java)22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java)14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java)2
-rw-r--r--searchsummary/CMakeLists.txt1
-rw-r--r--searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt8
-rw-r--r--searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp217
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt3
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp89
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h29
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp141
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h36
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp172
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h34
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h21
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp6
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h1
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp2
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h2
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp7
-rw-r--r--searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h1
-rw-r--r--service-monitor/pom.xml23
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java8
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java141
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java67
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java50
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java127
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java75
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java4
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java42
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java28
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java38
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java5
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java75
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java23
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java16
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java41
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java102
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java139
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java57
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java75
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java73
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java42
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java35
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java29
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGeneratorTest.java (renamed from service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGeneratorTest.java)19
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java52
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java53
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java44
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java13
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java97
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java24
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java49
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java21
-rw-r--r--service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java11
-rw-r--r--valgrind-suppressions.txt17
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java43
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java25
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java39
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java46
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java6
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java41
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java26
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java3
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java16
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java13
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java10
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java24
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java123
252 files changed, 6415 insertions, 5215 deletions
diff --git a/athenz-identity-provider-service/pom.xml b/athenz-identity-provider-service/pom.xml
index 86d4defa861..982cb89f2bf 100644
--- a/athenz-identity-provider-service/pom.xml
+++ b/athenz-identity-provider-service/pom.xml
@@ -131,6 +131,14 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-deprecation</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ </configuration>
</plugin>
</plugins>
</build>
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java
index f1fc938d3ea..2a517e06ae2 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java
+++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java
@@ -23,11 +23,11 @@ import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.PrivateKey;
-import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
+import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -45,7 +45,6 @@ import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.g
@SuppressWarnings("unused") // Component injected into Jetty connector factory
public class AthenzSslKeyStoreConfigurator extends AbstractComponent implements SslKeyStoreConfigurator {
private static final Logger log = Logger.getLogger(AthenzSslKeyStoreConfigurator.class.getName());
- private static final SecureRandom secureRandom = new SecureRandom();
private static final String CERTIFICATE_ALIAS = "athenz";
private static final Duration EXPIRATION_MARGIN = Duration.ofHours(6);
@@ -172,12 +171,7 @@ public class AthenzSslKeyStoreConfigurator extends AbstractComponent implements
}
private static char[] generateKeystorePassword() {
- int length = 128;
- char[] pwd = new char[length];
- for (int i = 0; i < length; i++) {
- pwd[i] = (char) secureRandom.nextInt();
- }
- return pwd;
+ return UUID.randomUUID().toString().toCharArray();
}
private class AthenzCertificateUpdater implements Runnable {
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java
index 728406c297f..59126fd023f 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java
+++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java
@@ -7,6 +7,7 @@ import com.yahoo.net.HostName;
import com.yahoo.vespa.athenz.api.AthenzService;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider;
@@ -27,7 +28,10 @@ import java.util.Objects;
import java.util.Set;
/**
+ * Generates a signed identity document for a given hostname and type
+ *
* @author mortent
+ * @author bjorncs
*/
public class IdentityDocumentGenerator {
@@ -47,10 +51,10 @@ public class IdentityDocumentGenerator {
this.keyProvider = keyProvider;
}
- public SignedIdentityDocument generateSignedIdentityDocument(String hostname) {
+ public SignedIdentityDocument generateSignedIdentityDocument(String hostname, IdentityType identityType) {
Node node = nodeRepository.getNode(hostname).orElseThrow(() -> new RuntimeException("Unable to find node " + hostname));
try {
- IdentityDocument identityDocument = generateIdDocument(node);
+ IdentityDocument identityDocument = generateIdDocument(node, identityType);
String identityDocumentString = Utils.getMapper().writeValueAsString(EntityBindingsMapper.toIdentityDocumentEntity(identityDocument));
String encodedIdentityDocument =
@@ -70,13 +74,18 @@ public class IdentityDocumentGenerator {
toZoneDnsSuffix(zone, zoneConfig.certDnsSuffix()),
new AthenzService(zoneConfig.domain(), zoneConfig.serviceName()),
URI.create(zoneConfig.ztsUrl()),
- SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION);
+ SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION,
+ identityDocument.configServerHostname(),
+ identityDocument.instanceHostname(),
+ identityDocument.createdAt(),
+ identityDocument.ipAddresses(),
+ identityType);
} catch (Exception e) {
throw new RuntimeException("Exception generating identity document: " + e.getMessage(), e);
}
}
- private IdentityDocument generateIdDocument(Node node) {
+ private IdentityDocument generateIdDocument(Node node, IdentityType identityType) {
Allocation allocation = node.allocation().orElseThrow(() -> new RuntimeException("No allocation for node " + node.hostname()));
VespaUniqueInstanceId providerUniqueId = new VespaUniqueInstanceId(
allocation.membership().index(),
@@ -85,17 +94,10 @@ public class IdentityDocumentGenerator {
allocation.owner().application().value(),
allocation.owner().tenant().value(),
zone.region().value(),
- zone.environment().value());
+ zone.environment().value(),
+ identityType);
- // TODO: Hack to allow access from docker containers to non-ipv6 services.
- // Remove when yca-bridge is no longer needed
Set<String> ips = new HashSet<>(node.ipAddresses());
- if(node.parentHostname().isPresent()) {
- String parentHostName = node.parentHostname().get();
- nodeRepository.getNode(parentHostName)
- .map(Node::ipAddresses)
- .ifPresent(ips::addAll);
- }
return new IdentityDocument(
providerUniqueId,
HostName.getLocalhost(),
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java
index 93668006e26..219e12c7223 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java
+++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java
@@ -6,6 +6,7 @@ import com.yahoo.container.jaxrs.annotation.Component;
import com.yahoo.jdisc.http.servlet.ServletRequest;
import com.yahoo.log.LogLevel;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
import com.yahoo.vespa.athenz.identityprovider.api.bindings.IdentityDocumentApi;
import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity;
import com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodePrincipal;
@@ -18,7 +19,6 @@ import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
-import javax.ws.rs.QueryParam;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import java.util.logging.Logger;
@@ -41,15 +41,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi {
this.request = request;
}
- /**
- * @deprecated Use {@link #getNodeIdentityDocument(String)} and {@link #getTenantIdentityDocument(String)} instead.
- */
- @GET
- @Produces(MediaType.APPLICATION_JSON)
- @Deprecated
- @Override
- // TODO Make this method private when the rest api is not longer in use
- public SignedIdentityDocumentEntity getIdentityDocument(@QueryParam("hostname") String hostname) {
+ private SignedIdentityDocumentEntity getIdentityDocument(String hostname, IdentityType identityType) {
if (hostname == null) {
throw new BadRequestException("The 'hostname' query parameter is missing");
}
@@ -67,7 +59,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi {
throw new ForbiddenException();
}
try {
- return EntityBindingsMapper.toSignedIdentityDocumentEntity(identityDocumentGenerator.generateSignedIdentityDocument(hostname));
+ return EntityBindingsMapper.toSignedIdentityDocumentEntity(identityDocumentGenerator.generateSignedIdentityDocument(hostname, identityType));
} catch (Exception e) {
String message = String.format("Unable to generate identity doument for '%s': %s", hostname, e.getMessage());
log.log(LogLevel.ERROR, message, e);
@@ -80,7 +72,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi {
@Path("/node/{host}")
@Override
public SignedIdentityDocumentEntity getNodeIdentityDocument(@PathParam("host") String host) {
- return getIdentityDocument(host);
+ return getIdentityDocument(host, IdentityType.NODE);
}
@GET
@@ -88,7 +80,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi {
@Path("/tenant/{host}")
@Override
public SignedIdentityDocumentEntity getTenantIdentityDocument(@PathParam("host") String host) {
- return getIdentityDocument(host);
+ return getIdentityDocument(host, IdentityType.TENANT);
}
}
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java
index e457df37946..0201c46b253 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java
+++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java
@@ -82,6 +82,7 @@ public class InstanceValidator {
}
// If/when we dont care about logging exactly whats wrong, this can be simplified
+ // TODO Use identity type to determine if this check should be performed
boolean isSameIdentityAsInServicesXml(ApplicationId applicationId, String domain, String service) {
Optional<ApplicationInfo> applicationInfo = superModelProvider.getSuperModel().getApplicationInfo(applicationId);
diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java
index d7b061ca2f1..078ef1b7e39 100644
--- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java
+++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java
@@ -15,6 +15,7 @@ import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity;
@@ -81,7 +82,7 @@ public class IdentityDocumentGeneratorTest {
AthenzProviderServiceConfig config = getAthenzProviderConfig("domain", "service", dnsSuffix, ZONE);
IdentityDocumentGenerator identityDocumentGenerator =
new IdentityDocumentGenerator(config, nodeRepository, ZONE, keyProvider);
- SignedIdentityDocument signedIdentityDocument = identityDocumentGenerator.generateSignedIdentityDocument(containerHostname);
+ SignedIdentityDocument signedIdentityDocument = identityDocumentGenerator.generateSignedIdentityDocument(containerHostname, IdentityType.TENANT);
// Verify attributes
assertEquals(containerHostname, signedIdentityDocument.identityDocument().instanceHostname());
@@ -92,11 +93,11 @@ public class IdentityDocumentGeneratorTest {
assertEquals(expectedZoneDnsSuffix, signedIdentityDocument.dnsSuffix());
VespaUniqueInstanceId expectedProviderUniqueId =
- new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", region, environment);
+ new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", region, environment, IdentityType.TENANT);
assertEquals(expectedProviderUniqueId, signedIdentityDocument.providerUniqueId());
- // Validate that both parent and container ips are present
- assertThat(signedIdentityDocument.identityDocument().ipAddresses(), Matchers.containsInAnyOrder("127.0.0.1", "::1"));
+ // Validate that container ips are present
+ assertThat(signedIdentityDocument.identityDocument().ipAddresses(), Matchers.containsInAnyOrder("::1"));
SignedIdentityDocumentEntity signedIdentityDocumentEntity = EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument);
diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java
index 54786c86cd3..54411b424eb 100644
--- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java
+++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java
@@ -143,7 +143,12 @@ public class InstanceValidatorTest {
"dnssuffix",
"service",
URI.create("http://localhost/zts"),
- 1));
+ 1,
+ identityDocument.configServerHostname,
+ identityDocument.instanceHostname,
+ identityDocument.createdAt,
+ identityDocument.ipAddresses,
+ null)); // TODO Remove support for legacy representation without type
} catch (Exception e) {
throw new RuntimeException(e);
}
diff --git a/bundle-plugin-test/pom.xml b/bundle-plugin-test/pom.xml
index 5ae5496b1b0..53be71352c8 100644
--- a/bundle-plugin-test/pom.xml
+++ b/bundle-plugin-test/pom.xml
@@ -48,6 +48,14 @@
<artifactId>scala-library</artifactId>
<scope>provided</scope>
</dependency>
+
+ <dependency>
+ <!-- Added to verify that module-info.class can be handled by bundle-plugin without throwing an exception. -->
+ <groupId>javax.xml.bind</groupId>
+ <artifactId>jaxb-api</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+
</dependencies>
<build>
<plugins>
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala
index 903ad94e9e8..539684f2024 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala
@@ -9,7 +9,7 @@ import collection.mutable
* Picks up classes used in class files.
* @author tonytv
*/
-private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with AnnotationVisitorTrait with AttributeVisitorTrait {
+private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM6) with AnnotationVisitorTrait with AttributeVisitorTrait {
private var name : String = null
protected val imports : ImportsSet = mutable.Set()
protected var exportPackageAnnotation: Option[ExportPackageAnnotation] = None
@@ -32,7 +32,7 @@ private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with Annota
imports ++= getClassName(Type.getType(desc)).toList
AnalyzeSignatureVisitor.analyzeField(signature, this)
- new FieldVisitor(Opcodes.ASM5) with SubVisitorTrait with AttributeVisitorTrait with AnnotationVisitorTrait {
+ new FieldVisitor(Opcodes.ASM6) with SubVisitorTrait with AttributeVisitorTrait with AnnotationVisitorTrait {
val analyzeClassVisitor = AnalyzeClassVisitor.this
override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor = super.visitAnnotation(desc, visible)
@@ -68,7 +68,7 @@ private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with Annota
def visitExportPackage(): AnnotationVisitor = {
def defaultVersionValue[T](name: String) = classOf[Version].getMethod(name).getDefaultValue().asInstanceOf[T]
- new AnnotationVisitor(Opcodes.ASM5) {
+ new AnnotationVisitor(Opcodes.ASM6) {
var major: Int = defaultVersionValue("major")
var minor: Int = defaultVersionValue("minor")
var micro: Int = defaultVersionValue("micro")
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala
index 535ee2832c8..a8032b6a912 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala
@@ -8,7 +8,7 @@ import org.objectweb.asm._
* @author tonytv
*/
private class AnalyzeMethodVisitor(val analyzeClassVisitor : AnalyzeClassVisitor)
- extends MethodVisitor(Opcodes.ASM5) with AnnotationVisitorTrait with AttributeVisitorTrait with SubVisitorTrait {
+ extends MethodVisitor(Opcodes.ASM6) with AnnotationVisitorTrait with AttributeVisitorTrait with SubVisitorTrait {
override def visitParameterAnnotation(parameter: Int, desc: String, visible: Boolean): AnnotationVisitor = super.visitParameterAnnotation(parameter, desc, visible)
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala
index 58a43b04d20..5bb8304cf1e 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala
@@ -10,7 +10,7 @@ import org.objectweb.asm.signature.{SignatureReader, SignatureVisitor}
*/
private class AnalyzeSignatureVisitor(val analyzeClassVisitor: AnalyzeClassVisitor)
- extends SignatureVisitor(Opcodes.ASM5)
+ extends SignatureVisitor(Opcodes.ASM6)
with SubVisitorTrait {
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala
index 0ceaced1440..0bf6ee4a6b4 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala
@@ -17,7 +17,7 @@ private trait AnnotationVisitorTrait {
}
def visitAnnotationDefault(): AnnotationVisitor =
- new AnnotationVisitor(Opcodes.ASM5) {
+ new AnnotationVisitor(Opcodes.ASM6) {
override def visit(name: String, value: AnyRef) {}
override def visitEnum(name: String, desc: String, value: String) {
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala
index d217f720d1a..631884c58e3 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala
@@ -8,7 +8,10 @@ package object classanalysis {
type ImportsSet = mutable.Set[String]
def internalNameToClassName(internalClassName: String) : Option[String] = {
- getClassName(Type.getObjectType(internalClassName))
+ internalClassName match {
+ case null => None
+ case _ => getClassName(Type.getObjectType(internalClassName))
+ }
}
def getClassName(aType: Type): Option[String] = {
diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala
index d66edf88702..67ce45ed7c6 100644
--- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala
+++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala
@@ -210,7 +210,7 @@ class GenerateOsgiManifestMojo extends AbstractMojo {
private def analyzeProjectClasses() : PackageTally = {
val outputDirectory = new File(project.getBuild.getOutputDirectory)
- val analyzedClasses = allDescendantFiles(outputDirectory).filter(_.getName.endsWith(".class")).
+ val analyzedClasses = allDescendantFiles(outputDirectory).filter(file => isClassToAnalyze(file.getName)).
map(Analyze.analyzeClass)
PackageTally.fromAnalyzedClassFiles(analyzedClasses)
@@ -230,7 +230,7 @@ class GenerateOsgiManifestMojo extends AbstractMojo {
for {
entry <- toStream(jarFile.entries())
if !entry.isDirectory
- if entry.getName.endsWith(".class")
+ if isClassToAnalyze(entry.getName)
metaData = analyzeClass(jarFile, entry)
} yield metaData
@@ -278,6 +278,9 @@ object GenerateOsgiManifestMojo {
}
}
+ def isClassToAnalyze(name: String): Boolean =
+ name.endsWith(".class") && ! name.endsWith("module-info.class")
+
def emptyToNone(str: String) =
Option(str) map {_.trim} filterNot {_.isEmpty}
}
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java
index 11f9add6b25..441ef273a6f 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java
@@ -66,6 +66,11 @@ public class ValidationOverrides {
return false;
}
+ public static String toAllowMessage(ValidationId id) {
+ return "To allow this add <allow until='yyyy-mm-dd'>" + id + "</allow> to validation-overrides.xml" +
+ ", see https://docs.vespa.ai/documentation/reference/validation-overrides.html";
+ }
+
/** Returns the XML form of this, or null if it was not created by fromXml, nor is empty */
public String xmlForm() { return xmlForm; }
@@ -155,7 +160,9 @@ public class ValidationOverrides {
/** Returns "validationId: message" */
@Override
- public String getMessage() { return validationId + ": " + super.getMessage(); }
+ public String getMessage() {
+ return validationId + ": " + super.getMessage() + ". " + toAllowMessage(validationId);
+ }
}
diff --git a/config-model-fat/pom.xml b/config-model-fat/pom.xml
index 3ef9925510c..649d8a37bf6 100644
--- a/config-model-fat/pom.xml
+++ b/config-model-fat/pom.xml
@@ -25,6 +25,13 @@
<artifactId>guava</artifactId>
<version>13.0.1</version>
</dependency>
+ <dependency>
+ <!-- TODO: can probably be removed. Added to get the same set of embedded deps with maven-bundle-plugin 3.5 as with 2.4. -->
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>annotations</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<dependency>
<groupId>com.yahoo.vespa</groupId>
@@ -114,8 +121,6 @@
<plugin>
<groupId>org.apache.felix</groupId>
<artifactId>maven-bundle-plugin</artifactId>
- <!-- version >= 2.5.0 causes java.lang.ArrayIndexOutOfBoundsException: 176 -->
- <version>2.4.0</version>
<extensions>true</extensions>
<configuration>
<instructions>
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
new file mode 100644
index 00000000000..effa261be3b
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
@@ -0,0 +1,674 @@
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for replacing instances of a pseudofeature for imported ML
+ * ranking models with native Vespa ranking expressions.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ ExpressionNode transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
+ String output = chooseOutput(signature, store.arguments().output());
+ if (signature.skippedOutputs().containsKey(output)) {
+ String message = "Could not import model output '" + output + "'";
+ if (!signature.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + signature.skippedOutputs().get(output);
+ }
+ if (!signature.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", signature.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
+ return store.readConverted().getRoot();
+ }
+
+ /**
+ * Returns the specified, existing signature, or the only signature if none is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
+ if ( ! signatureName.isPresent()) {
+ if (importResult.signatures().size() == 0)
+ throw new IllegalArgumentException("No signatures are available");
+ if (importResult.signatures().size() > 1)
+ throw new IllegalArgumentException("Model has multiple signatures (" +
+ Joiner.on(", ").join(importResult.signatures().keySet()) +
+ "), one must be specified " +
+ "as a second argument to tensorflow()");
+ return importResult.signatures().values().stream().findFirst().get();
+ }
+ else {
+ ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
+ if (signature == null)
+ throw new IllegalArgumentException("Model does not have the specified signature '" +
+ signatureName.get() + "'");
+ return signature;
+ }
+ }
+
+ /**
+ * Returns the specified, existing output expression, or the only output expression if no output name is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (signature.outputs().size() == 0)
+ throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
+ if (signature.outputs().size() > 1)
+ throw new IllegalArgumentException(signature + " has multiple outputs (" +
+ Joiner.on(", ").join(signature.outputs().keySet()) +
+ "), one must be specified " +
+ "as a third argument to tensorflow()");
+ return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ String output = signature.outputs().get(outputName.get());
+ if (output == null) {
+ if (signature.skippedOutputs().containsKey(outputName.get()))
+ throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
+ signature.skippedOutputs().get(outputName.get()));
+ else
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ }
+ return output;
+ }
+ }
+
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ typeMismatchExplanation(constantValue.type(), macroType));
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
+ }
+ }
+
+ private void transformGeneratedMacro(ModelStore store,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ store.writeMacro(macroName, expression);
+ }
+
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
+ private String skippedOutputsDescription(ImportedModel.Signature signature) {
+ if (signature.skippedOutputs().isEmpty()) return "";
+ StringBuilder b = new StringBuilder(": ");
+ signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
+ return b.toString();
+ }
+
+ /**
+ * Verify that the macros referred in the given expression exists in the given rank profile,
+ * and return tensors of the types specified in requiredMacros.
+ */
+ private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+ // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
+ // phase and summary features), as it may only resolve correctly given those bindings
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers input '" + macroName + "'. " +
+ typeMismatchExplanation(requiredType, actualType));
+ }
+ }
+
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
+ (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
+ "in query profile types - see the documentation."
+ : "");
+ }
+
+ /**
+ * Add the generated macros to the rank profile
+ */
+ private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
+ */
+ private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
+ }
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+ throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+ "but this macro is not present in " + profile);
+ }
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ }
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
+ }
+
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names, model);
+ }
+ }
+
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ ModelStore(ApplicationPackage application, FeatureArguments arguments) {
+ this.application = application;
+ this.arguments = arguments;
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasStoredModel() {
+ try {
+ return application.getFile(arguments.expressionPath()).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File modelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+ /** Adds this macro expression to the application package to it can be read later. */
+ void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Reads the information about all the large (aka ranking) constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ List<RankingConstant> readLargeConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ Path writeLargeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
+
+ }
+
+ /** Encapsulates the arguments to the import feature */
+ static abstract class FeatureArguments {
+
+ Path modelPath;
+
+ /** Optional arguments */
+ Optional<String> signature, output;
+
+ /** Returns modelPath with slashes replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 1c41ad8284e..44eeb364603 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -2,58 +2,20 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the onnx(model-path, output)
@@ -63,12 +25,12 @@ import java.util.stream.Collectors;
* @author bratseth
* @author lesters
*/
-public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class OnnxFeatureConverter extends MLImportFeatureConverter {
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, OnnxModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -84,7 +46,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -98,597 +61,24 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFromOnnxModel(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles) {
- OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> onnxImporter.importModel(store.arguments().modelName(),
- store.onnxModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- String output = chooseOutput(model, store.arguments().output());
- if (model.skippedOutputs().containsKey(output)) {
- String message = "Could not import Onnx model output '" + output + "'";
- if (!model.skippedOutputs().get(output).isEmpty()) {
- message += ": " + model.skippedOutputs().get(output);
- }
- if (!model.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", model.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(OnnxModel model, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (model.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
- if (model.outputs().size() > 1)
- throw new IllegalArgumentException("Onnx model has multiple outputs (" +
- Joiner.on(", ").join(model.outputs().keySet()) +
- "), one must be specified " +
- "as a second argument to onnx()");
- return model.outputs().get(model.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = model.outputs().get(outputName.get());
- if (output == null) {
- if (model.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- model.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- "The required type of this is " + constantValue.type() +
- ", but the macro returns " + macroType);
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store, RankProfile profile,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(OnnxModel model) {
- if (model.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro produces type " + actualType);
- }
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which contains the source model to use for these arguments
- */
- public File onnxModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package to it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> output;
-
- public FeatureArguments(Arguments arguments) {
+ static class OnnxFeatureArguments extends FeatureArguments {
+ public OnnxFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
- "the onnx model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
modelPath = Path.fromString(asString(arguments.expressions().get(0)));
output = optionalArgument(1, arguments);
+ signature = Optional.of("default");
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 41da32f64c3..27e1ad51b33 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,59 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -62,12 +22,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -83,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -95,565 +56,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
private ExpressionNode transformFromTensorFlowModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> tensorFlowImporter.importModel(store.arguments().modelName(),
- store.tensorFlowModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import TensorFlow model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
- }
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
- }
-
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- typeMismatchExplanation(constantValue.type(), macroType));
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> tensorFlowImporter.importModel(store.arguments().modelName(),
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
- }
- }
-
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
- return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
- (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
- "in query profile types - see the documentation."
- : "");
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
-
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which (if hasTensorFlowModels is true)
- * contains the source model to use for these arguments
- */
- public File tensorFlowModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package to it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
+ static class TensorFlowFeatureArguments extends FeatureArguments {
+ public TensorFlowFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
@@ -661,68 +76,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
signature = optionalArgument(1, arguments);
output = optionalArgument(2, arguments);
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- signature.ifPresent(s -> fileName.append(s).append("."));
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Host.java b/config-model/src/main/java/com/yahoo/vespa/model/Host.java
index 0adfe9e4bdb..624a9fd4da7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/Host.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/Host.java
@@ -42,16 +42,14 @@ public final class Host extends AbstractConfigProducer<AbstractConfigProducer<?>
private void checkName(HostSystem parent, String hostname) {
// Give a warning if the host does not exist
- // Host exists - warn if given hostname is not a fully qualified one.
- String canonical = hostname;
try {
- canonical = parent.getCanonicalHostname(hostname);
+ Object address = java.net.InetAddress.getByName(hostname);
} catch (UnknownHostException e) {
- deployLogger().log(Level.WARNING, "Unable to find canonical hostname of host: " + hostname);
+ deployLogger().log(Level.WARNING, "Unable to lookup IP address of host: " + hostname);
}
- if ((null != canonical) && (! hostname.equals(canonical))) {
+ if (! hostname.contains(".")) {
deployLogger().log(Level.WARNING, "Host named '" + hostname + "' may not receive any config " +
- "since it does not match its canonical hostname: " + canonical);
+ "since it is not a canonical hostname");
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java
index 6467199d9f9..fc46ed18dde 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java
@@ -110,6 +110,13 @@ public class VespaMetricSet {
metrics.add(new Metric("jdisc.memory_mappings.max"));
metrics.add(new Metric("jdisc.open_file_descriptors.max"));
+ metrics.add(new Metric("jdisc.gc.count.average"));
+ metrics.add(new Metric("jdisc.gc.count.max"));
+ metrics.add(new Metric("jdisc.gc.count.last"));
+ metrics.add(new Metric("jdisc.gc.ms.average"));
+ metrics.add(new Metric("jdisc.gc.ms.max"));
+ metrics.add(new Metric("jdisc.gc.ms.last"));
+
metrics.add(new Metric("jdisc.deactivated_containers.total.last"));
metrics.add(new Metric("jdisc.deactivated_containers.with_retained_refs.last"));
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 1c54d12d8b3..d9beab6e2f2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReference() throws ParseException {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
- assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
- }
-
- @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
@@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceSpecifyingOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'add')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- }
-
- @Test
public void testOnnxReferenceMissingMacro() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
@@ -145,7 +129,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -163,8 +147,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
- "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index d288a396732..7228af2b0de 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
"but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
new StoringApplicationPackage(applicationDir));
search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase {
application);
search.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- public static class StoringApplicationPackageFile extends ApplicationFile {
+ static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java
index 765acf9e27b..4c3583ba0ae 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation.change;
+import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.ConfigChangeRefeedAction;
import com.yahoo.vespa.model.VespaModel;
@@ -33,7 +35,8 @@ public class ClusterSizeReductionValidatorTest {
fail("Expected exception due to cluster size reduction");
}
catch (IllegalArgumentException expected) {
- assertEquals("cluster-size-reduction: Size reduction in 'default' is too large. Current size: 30, new size: 14. New size must be at least 50% of the current size",
+ assertEquals("cluster-size-reduction: Size reduction in 'default' is too large. Current size: 30, new size: 14. New size must be at least 50% of the current size. " +
+ ValidationOverrides.toAllowMessage(ValidationId.clusterSizeReduction),
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
index 25ad6dbc620..ee58ca67b02 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation.change;
+import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.application.validation.ValidationTester;
import com.yahoo.yolean.Exceptions;
@@ -24,7 +26,8 @@ public class ContentClusterRemovalValidatorTest {
fail("Expected exception due to content cluster id change");
}
catch (IllegalArgumentException expected) {
- assertEquals("content-cluster-removal: Content cluster 'contentClusterId' is removed. This will cause loss of all data in this cluster",
+ assertEquals("content-cluster-removal: Content cluster 'contentClusterId' is removed. This will cause loss of all data in this cluster. " +
+ ValidationOverrides.toAllowMessage(ValidationId.contentClusterRemoval),
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java
index a52c6d7c7a2..ca45520711e 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java
@@ -1,6 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation.change;
+import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.application.validation.ValidationTester;
import com.yahoo.yolean.Exceptions;
@@ -28,7 +30,8 @@ public class ContentTypeRemovalValidatorTest {
}
catch (IllegalArgumentException expected) {
assertEquals("content-type-removal: Type 'music' is removed in content cluster 'test'. " +
- "This will cause loss of all data of this type",
+ "This will cause loss of all data of this type. " +
+ ValidationOverrides.toAllowMessage(ValidationId.contentTypeRemoval),
Exceptions.toMessageString(expected));
}
}
diff --git a/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java b/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java
index 76241c560e4..2b6f9e24dd3 100644
--- a/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java
+++ b/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java
@@ -204,7 +204,7 @@ public abstract class ConfigSubscription<T extends ConfigInstance> {
void setInternalRedeploy(boolean internalRedeploy) {
ConfigState<T> prev = config.get();
- this.config.set(new ConfigState<>(prev.isGenerationChanged(), prev.getGeneration(), prev.isConfigChanged(), internalRedeploy, prev.getConfig()));
+ this.config.set(new ConfigState<>(prev.isGenerationChanged(), prev.getGeneration(), internalRedeploy, prev.isConfigChanged(), prev.getConfig()));
}
/**
diff --git a/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java b/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java
index 88af414e28d..243c9e932a8 100644
--- a/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java
+++ b/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java
@@ -30,10 +30,10 @@ import com.yahoo.vespa.config.protocol.Trace;
* as context, and puts the requests objects on a queue on the subscription,
* for handling by the user thread.
*
- * @author vegardh
- * @since 5.1
+ * @author Vegard Havdal
*/
public class JRTConfigRequester implements RequestWaiter {
+
private static final Logger log = Logger.getLogger(JRTConfigRequester.class.getName());
public static final ConfigSourceSet defaultSourceSet = ConfigSourceSet.createDefault();
private static final int TRACELEVEL = 6;
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java
index 4c32e635391..e9d400591e8 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java
@@ -23,7 +23,7 @@ import com.yahoo.path.Path;
import com.yahoo.slime.Slime;
import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.config.server.application.Application;
-import com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker;
+import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker;
import com.yahoo.vespa.config.server.application.ApplicationSet;
import com.yahoo.vespa.config.server.application.FileDistributionStatus;
import com.yahoo.vespa.config.server.application.HttpProxy;
@@ -88,7 +88,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye
private final TenantRepository tenantRepository;
private final Optional<Provisioner> hostProvisioner;
- private final ApplicationConvergenceChecker convergeChecker;
+ private final ConfigConvergenceChecker convergeChecker;
private final HttpProxy httpProxy;
private final Clock clock;
private final DeployLogger logger = new SilentDeployLogger();
@@ -99,22 +99,22 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye
@Inject
public ApplicationRepository(TenantRepository tenantRepository,
HostProvisionerProvider hostProvisionerProvider,
- ApplicationConvergenceChecker applicationConvergenceChecker,
+ ConfigConvergenceChecker configConvergenceChecker,
HttpProxy httpProxy,
ConfigserverConfig configserverConfig) {
this(tenantRepository, hostProvisionerProvider.getHostProvisioner(),
- applicationConvergenceChecker, httpProxy, configserverConfig, Clock.systemUTC(), new FileDistributionStatus());
+ configConvergenceChecker, httpProxy, configserverConfig, Clock.systemUTC(), new FileDistributionStatus());
}
// For testing
public ApplicationRepository(TenantRepository tenantRepository,
Provisioner hostProvisioner,
Clock clock) {
- this(tenantRepository, new ApplicationConvergenceChecker(), hostProvisioner, clock);
+ this(tenantRepository, new ConfigConvergenceChecker(), hostProvisioner, clock);
}
public ApplicationRepository(TenantRepository tenantRepository,
- ApplicationConvergenceChecker convergenceChecker,
+ ConfigConvergenceChecker convergenceChecker,
Provisioner hostProvisioner,
Clock clock) {
this(tenantRepository, Optional.of(hostProvisioner),
@@ -124,14 +124,14 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye
private ApplicationRepository(TenantRepository tenantRepository,
Optional<Provisioner> hostProvisioner,
- ApplicationConvergenceChecker applicationConvergenceChecker,
+ ConfigConvergenceChecker configConvergenceChecker,
HttpProxy httpProxy,
ConfigserverConfig configserverConfig,
Clock clock,
FileDistributionStatus fileDistributionStatus) {
this.tenantRepository = tenantRepository;
this.hostProvisioner = hostProvisioner;
- this.convergeChecker = applicationConvergenceChecker;
+ this.convergeChecker = configConvergenceChecker;
this.httpProxy = httpProxy;
this.clock = clock;
this.configserverConfig = configserverConfig;
@@ -373,12 +373,12 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye
// ---------------- Convergence ----------------------------------------------------------------
- public HttpResponse serviceConvergenceCheck(ApplicationId applicationId, String hostname, URI uri) {
- return convergeChecker.serviceConvergenceCheck(getApplication(applicationId), hostname, uri);
+ public HttpResponse checkServiceForConfigConvergence(ApplicationId applicationId, String hostAndPort, URI uri) {
+ return convergeChecker.checkService(getApplication(applicationId), hostAndPort, uri);
}
- public HttpResponse serviceListToCheckForConfigConvergence(ApplicationId applicationId, URI uri) {
- return convergeChecker.serviceListToCheckForConfigConvergence(getApplication(applicationId), uri);
+ public HttpResponse servicesToCheckForConfigConvergence(ApplicationId applicationId, URI uri) {
+ return convergeChecker.servicesToCheck(getApplication(applicationId), uri);
}
// ---------------- Session operations ----------------------------------------------------------------
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java
index 9793a441355..916fde97e35 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java
@@ -3,56 +3,70 @@ package com.yahoo.vespa.config.server;
import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
+import com.yahoo.concurrent.DaemonThreadFactory;
+import com.yahoo.container.handler.VipStatus;
import com.yahoo.container.jdisc.state.StateMonitor;
import com.yahoo.log.LogLevel;
import com.yahoo.vespa.config.server.rpc.RpcServer;
import com.yahoo.vespa.config.server.version.VersionState;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
/**
* Main component that bootstraps and starts config server threads.
*
- * @author lulf
- * @since 5.1
+ * If config server has been upgraded to a new version since the last time it was running it will redeploy all
+ * applications. If that is done successfully the RPC server will start and the health status code will change from
+ * 'initializing' to 'up' and the config server will be put into rotation (start serving status.html with 200 OK)
+ *
+ * @author Ulf Lilleengen
+ * @author hmusum
*/
public class ConfigServerBootstrap extends AbstractComponent implements Runnable {
private static final java.util.logging.Logger log = java.util.logging.Logger.getLogger(ConfigServerBootstrap.class.getName());
+ private static final ExecutorService rpcServerExecutor = Executors.newSingleThreadExecutor(new DaemonThreadFactory("config server RPC server"));
+ private static final String vipStatusClusterIdentifier = "configserver";
private final ApplicationRepository applicationRepository;
private final RpcServer server;
private final Thread serverThread;
private final VersionState versionState;
private final StateMonitor stateMonitor;
+ private final VipStatus vipStatus;
// The tenants object is injected so that all initial requests handlers are
// added to the rpc server before it starts answering rpc requests.
@SuppressWarnings("WeakerAccess")
@Inject
public ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server,
- VersionState versionState, StateMonitor stateMonitor) {
- this(applicationRepository, server, versionState, stateMonitor, true);
+ VersionState versionState, StateMonitor stateMonitor, VipStatus vipStatus) {
+ this(applicationRepository, server, versionState, stateMonitor, vipStatus, true);
}
// For testing only
- ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server,
- VersionState versionState, StateMonitor stateMonitor, boolean startMainThread) {
+ ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server, VersionState versionState,
+ StateMonitor stateMonitor, VipStatus vipStatus, boolean startMainThread) {
this.applicationRepository = applicationRepository;
this.server = server;
this.versionState = versionState;
this.stateMonitor = stateMonitor;
this.serverThread = new Thread(this, "configserver main");
+ this.vipStatus = vipStatus;
+ initializing(); // Initially take server out of rotation
if (startMainThread)
start();
}
- private void start() {
- serverThread.start();
- }
-
@Override
public void deconstruct() {
log.log(LogLevel.INFO, "Stopping config server");
+ down();
server.stop();
+ rpcServerExecutor.shutdown();
try {
serverThread.join();
} catch (InterruptedException e) {
@@ -74,9 +88,8 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable
return; // Status will not be set to 'up' since we return here
}
}
- stateMonitor.status(StateMonitor.Status.up);
- log.log(LogLevel.INFO, "Starting RPC server");
- server.run();
+ startRpcServer();
+ up();
do {
try {
Thread.sleep(1000);
@@ -85,13 +98,51 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable
break;
}
} while (server.isRunning());
+ down();
log.log(LogLevel.INFO, "RPC server stopped");
- stateMonitor.status(StateMonitor.Status.down);
}
StateMonitor.Status status() {
return stateMonitor.status();
}
+ private void start() {
+ serverThread.start();
+ }
+
+ private void up() {
+ stateMonitor.status(StateMonitor.Status.up);
+ vipStatus.addToRotation(vipStatusClusterIdentifier);
+ }
+
+ private void down() {
+ stateMonitor.status(StateMonitor.Status.down);
+ vipStatus.removeFromRotation(vipStatusClusterIdentifier);
+ }
+
+ private void initializing() {
+ // This is default value (from config), so not strictly necessary
+ stateMonitor.status(StateMonitor.Status.initializing);
+ vipStatus.removeFromRotation(vipStatusClusterIdentifier);
+ }
+
+ private void startRpcServer() {
+ log.log(LogLevel.INFO, "Starting RPC server");
+ rpcServerExecutor.execute(server);
+
+ Instant end = Instant.now().plus(Duration.ofSeconds(10));
+ while (!server.isRunning() && Instant.now().isBefore(end)) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ log.log(LogLevel.ERROR, "Got interrupted", e);
+ break;
+ }
+ }
+ if (!server.isRunning())
+ throw new RuntimeException("RPC server not started in 10 seconds");
+ log.log(LogLevel.INFO, "RPC server started");
+ }
+
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java
index 58168a7526f..4978f5f274d 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java
@@ -29,11 +29,10 @@ import java.util.stream.Collectors;
/**
* Checks for convergence of config generation for a given application.
*
- * @author lulf
+ * @author Ulf Lilleengen
* @author hmusum
*/
-public class ApplicationConvergenceChecker extends AbstractComponent {
-
+public class ConfigConvergenceChecker extends AbstractComponent {
private static final String statePath = "/state/v1/";
private static final String configSubPath = "config";
private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList(
@@ -49,15 +48,15 @@ public class ApplicationConvergenceChecker extends AbstractComponent {
private final Client client = ClientBuilder.newClient();
@Inject
- public ApplicationConvergenceChecker() {
- this(ApplicationConvergenceChecker::createStateApi);
+ public ConfigConvergenceChecker() {
+ this(ConfigConvergenceChecker::createStateApi);
}
- public ApplicationConvergenceChecker(StateApiFactory stateApiFactory) {
+ public ConfigConvergenceChecker(StateApiFactory stateApiFactory) {
this.stateApiFactory = stateApiFactory;
}
- public ServiceListResponse serviceListToCheckForConfigConvergence(Application application, URI uri) {
+ public ServiceListResponse servicesToCheck(Application application, URI uri) {
List<ServiceInfo> servicesToCheck = new ArrayList<>();
application.getModel().getHosts()
.forEach(host -> host.getServices().stream()
@@ -69,7 +68,7 @@ public class ApplicationConvergenceChecker extends AbstractComponent {
currentGeneration);
}
- public ServiceResponse serviceConvergenceCheck(Application application, String hostAndPortToCheck, URI uri) {
+ public ServiceResponse checkService(Application application, String hostAndPortToCheck, URI uri) {
Long wantedGeneration = application.getApplicationGeneration();
try {
if (! hostInApplication(application, hostAndPortToCheck))
@@ -157,8 +156,7 @@ public class ApplicationConvergenceChecker extends AbstractComponent {
return false;
}
- static class ServiceListResponse extends JSONResponse {
- final Cursor debug;
+ private static class ServiceListResponse extends JSONResponse {
// Pre-condition: servicesToCheck has a state port
private ServiceListResponse(int status, List<ServiceInfo> servicesToCheck, URI uri, long wantedGeneration,
@@ -178,40 +176,28 @@ public class ApplicationConvergenceChecker extends AbstractComponent {
object.setLong("currentGeneration", currentGeneration);
object.setLong("wantedGeneration", wantedGeneration);
object.setBool("converged", currentGeneration >= wantedGeneration);
- // TODO: Remove debug when clients are not using it anymore
- debug = object.setObject("debug");
- debug.setLong("wantedGeneration", wantedGeneration);
}
}
static class ServiceResponse extends JSONResponse {
- final Cursor debug;
private ServiceResponse(int status, URI uri, String hostname, Long wantedGeneration) {
super(status);
object.setString("url", uri.toString());
object.setString("host", hostname);
object.setLong("wantedGeneration", wantedGeneration);
- // TODO: Remove debug when clients are not using it anymore
- debug = object.setObject("debug");
- debug.setString("host", hostname);
- debug.setLong("wantedGeneration", wantedGeneration);
}
static ServiceResponse createOkResponse(URI uri, String hostname, Long wantedGeneration, Long currentGeneration, boolean converged) {
ServiceResponse serviceResponse = new ServiceResponse(200, uri, hostname, wantedGeneration);
serviceResponse.object.setBool("converged", converged);
serviceResponse.object.setLong("currentGeneration", currentGeneration);
- // TODO: Remove debug when clients are not using it anymore
- serviceResponse.debug.setLong("currentGeneration", currentGeneration);
return serviceResponse;
}
static ServiceResponse createHostNotFoundInAppResponse(URI uri, String hostname, Long wantedGeneration) {
ServiceResponse serviceResponse = new ServiceResponse(410, uri, hostname, wantedGeneration);
serviceResponse.object.setString("problem", "Host:port (service) no longer part of application, refetch list of services.");
- // TODO: Remove debug when clients are not using it anymore
- serviceResponse.debug.setString("problem", "Host:port (service) no longer part of application, refetch list of services.");
return serviceResponse;
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java
index 2db89c2e8ed..36d76bbfc79 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java
@@ -10,7 +10,6 @@ import com.yahoo.jrt.Spec;
import com.yahoo.jrt.StringArray;
import com.yahoo.jrt.Supervisor;
import com.yahoo.jrt.Target;
-import com.yahoo.jrt.Transport;
import com.yahoo.log.LogLevel;
import com.yahoo.vespa.defaults.Defaults;
@@ -24,11 +23,12 @@ import java.util.logging.Logger;
public class FileDistributionImpl implements FileDistribution {
private final static Logger log = Logger.getLogger(FileDistributionImpl.class.getName());
- private final Supervisor supervisor = new Supervisor(new Transport());
+ private final Supervisor supervisor;
private final File fileReferencesDir;
- public FileDistributionImpl(ConfigserverConfig configserverConfig) {
+ public FileDistributionImpl(ConfigserverConfig configserverConfig, Supervisor supervisor) {
this.fileReferencesDir = new File(Defaults.getDefaults().underVespaHome(configserverConfig.fileReferencesDir()));
+ this.supervisor = supervisor;
}
@Override
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
index 6bca8b1c562..473ec913f50 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java
@@ -59,7 +59,9 @@ public class ApplicationHandler extends HttpHandler {
Tenant tenant = verifyTenantAndApplication(applicationId);
if (isServiceConvergeRequest(request)) {
- return applicationRepository.serviceConvergenceCheck(applicationId, getHostNameFromRequest(request), request.getUri());
+ // Expects both hostname and port in the request (hostname:port)
+ String hostAndPort = getHostNameFromRequest(request);
+ return applicationRepository.checkServiceForConfigConvergence(applicationId, hostAndPort, request.getUri());
}
if (isClusterControllerStatusRequest(request)) {
@@ -86,7 +88,7 @@ public class ApplicationHandler extends HttpHandler {
}
if (isServiceConvergeListRequest(request)) {
- return applicationRepository.serviceListToCheckForConfigConvergence(applicationId, request.getUri());
+ return applicationRepository.servicesToCheckForConfigConvergence(applicationId, request.getUri());
}
if (isFiledistributionStatusRequest(request)) {
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
index c6a390caf86..2a53f9ee45c 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java
@@ -51,7 +51,7 @@ public class ConfigServerMaintenance extends AbstractComponent {
this.defaultInterval = Duration.ofMinutes(configserverConfig.maintainerIntervalMinutes());
// TODO: Want job control or feature flag to control when to run this, for now use a very
// long interval to avoid running the maintainer
- this.tenantsMaintainerInterval = isCd || isTest
+ this.tenantsMaintainerInterval = isCd || isTest || configserverConfig.region().equals("us-central-1")
? defaultInterval
: Duration.ofMinutes(configserverConfig.tenantsMaintainerIntervalMinutes());
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java
index 2664a0bde8c..1d16283d938 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java
@@ -31,9 +31,10 @@ public class FileDistributionMaintainer extends Maintainer {
@Override
protected void maintain() {
- // TODO: For now only deletes files in CD system
+ // TODO: Delete files in all zones
boolean deleteFiles = (SystemName.from(configserverConfig.system()) == SystemName.cd)
- || Environment.from(configserverConfig.environment()).isTest();
+ || Environment.from(configserverConfig.environment()).isTest()
+ || configserverConfig.region().equals("us-central-1");
applicationRepository.deleteUnusedFiledistributionReferences(fileReferencesDir, deleteFiles);
}
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java
index 9de587ac17b..f1cf479a38a 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java
@@ -166,7 +166,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener {
}
public void run() {
- log.log(LogLevel.INFO, "Rpc server listening on port " + spec.port());
+ log.log(LogLevel.INFO, "Rpc will listen on port " + spec.port());
try {
Acceptor acceptor = supervisor.listen(spec);
isRunning = true;
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
index 8394494adca..15bc3c1fb46 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.config.server.session;
import com.google.inject.Inject;
import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.jrt.Supervisor;
+import com.yahoo.jrt.Transport;
import com.yahoo.vespa.config.server.filedistribution.FileDistributionImpl;
import com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider;
@@ -17,6 +19,7 @@ import java.io.File;
public class FileDistributionFactory {
private final ConfigserverConfig configserverConfig;
+ private final Supervisor supervisor = new Supervisor(new Transport());
@Inject
public FileDistributionFactory(ConfigserverConfig configserverConfig) {
@@ -24,7 +27,7 @@ public class FileDistributionFactory {
}
public FileDistributionProvider createProvider(File applicationPackage) {
- return new FileDistributionProvider(applicationPackage, new FileDistributionImpl(configserverConfig));
+ return new FileDistributionProvider(applicationPackage, new FileDistributionImpl(configserverConfig, supervisor));
}
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java
index 3557d7bf9ab..8d21a1b8c03 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java
@@ -11,7 +11,6 @@ import com.yahoo.vespa.config.server.session.LocalSessionRepo;
import com.yahoo.vespa.config.server.session.RemoteSessionRepo;
import com.yahoo.vespa.config.server.session.SessionFactory;
import com.yahoo.vespa.curator.Curator;
-import org.apache.zookeeper.Op;
import org.apache.zookeeper.data.Stat;
import java.time.Instant;
diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml
index 2eeefda63e7..79797854689 100644
--- a/configserver/src/main/resources/configserver-app/services.xml
+++ b/configserver/src/main/resources/configserver-app/services.xml
@@ -10,6 +10,10 @@
<initialStatus>initializing</initialStatus>
</config>
+ <config name="container.core.vip-status">
+ <initiallyInRotation>false</initiallyInRotation>
+ </config>
+
<accesslog type="vespa" fileNamePattern="logs/vespa/configserver/access.log.%Y%m%d%H%M%S" rotationScheme="date" compressOnRotation="true" symlinkName="access.log" />
<preprocess:include file='access-logging.xml' required='false' />
@@ -37,7 +41,7 @@
<component id="com.yahoo.vespa.config.server.host.ConfigRequestHostLivenessTracker" bundle="configserver" />
<component id="com.yahoo.container.jdisc.metric.state.StateMetricConsumerFactory" bundle="container-disc" />
<component id="com.yahoo.config.provision.Zone" bundle="config-provisioning" />
- <component id="com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker" bundle="configserver" />
+ <component id="com.yahoo.vespa.config.server.application.ConfigConvergenceChecker" bundle="configserver" />
<component id="com.yahoo.vespa.config.server.application.HttpProxy" bundle="configserver" />
<component id="com.yahoo.vespa.config.server.filedistribution.FileServer" bundle="configserver" />
<component id="com.yahoo.vespa.config.server.maintenance.ConfigServerMaintenance" bundle="configserver" />
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java
index 67cc87ae223..992d46d3115 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.config.server;
import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.container.handler.VipStatus;
import com.yahoo.container.jdisc.config.HealthMonitorConfig;
import com.yahoo.container.jdisc.state.StateMonitor;
import com.yahoo.jdisc.core.SystemTimer;
@@ -43,13 +44,17 @@ public class ConfigServerBootstrapTest {
assertTrue(versionState.isUpgraded());
RpcServer rpcServer = createRpcServer(configserverConfig);
- ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState, createStateMonitor());
- waitUntil(() -> bootstrap.status() == StateMonitor.Status.up, "failed waiting for status 'up'");
+ VipStatus vipStatus = new VipStatus();
+ ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState, createStateMonitor(), vipStatus);
+ assertFalse(vipStatus.isInRotation());
waitUntil(rpcServer::isRunning, "failed waiting for Rpc server running");
+ waitUntil(() -> bootstrap.status() == StateMonitor.Status.up, "failed waiting for status 'up'");
+ waitUntil(vipStatus::isInRotation, "failed waiting for server to be in rotation");
bootstrap.deconstruct();
assertEquals(StateMonitor.Status.down, bootstrap.status());
assertFalse(rpcServer.isRunning());
+ assertFalse(vipStatus.isInRotation());
}
@Test
@@ -69,13 +74,17 @@ public class ConfigServerBootstrapTest {
.resolve("sessions/2/services.xml"));
RpcServer rpcServer = createRpcServer(configserverConfig);
+ VipStatus vipStatus = new VipStatus();
ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState,
- createStateMonitor(), false /* do not call run method */);
+ createStateMonitor(), vipStatus, false /* do not call run method */);
+ assertFalse(vipStatus.isInRotation());
// Call method directly, to be sure that it is finished redeploying all applications and we can check status
bootstrap.run();
- // App is invalid, so bootstrapping was unsuccessful. Status should be 'initializing' and rpc server should not be running
+ // App is invalid, bootstrapping was unsuccessful. Status should be 'initializing',
+ // rpc server should not be running and it should be out of rotation
assertEquals(StateMonitor.Status.initializing, bootstrap.status());
assertFalse(rpcServer.isRunning());
+ assertFalse(vipStatus.isInRotation());
}
private void waitUntil(BooleanSupplier booleanSupplier, String messageIfWaitingFails) throws InterruptedException {
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java b/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java
deleted file mode 100644
index 5dce0607f90..00000000000
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.config.server;
-
-import com.yahoo.vespa.curator.Curator;
-import com.yahoo.vespa.curator.mock.MockCurator;
-import com.yahoo.vespa.config.server.zookeeper.ConfigCurator;
-import org.apache.curator.framework.CuratorFramework;
-import org.junit.Before;
-
-/**
- * For tests that require a Curator instance
- *
- * @author lulf
- * @since 5.16
- */
-public class TestWithCurator {
-
- protected ConfigCurator configCurator;
- protected CuratorFramework curatorFramework;
- protected Curator curator;
-
- @Before
- public void setupZKProvider() throws Exception {
- curator = new MockCurator();
- configCurator = ConfigCurator.create(curator);
- curatorFramework = curator.framework();
- }
-
-}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceCheckerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java
index 399169c122a..71052c8b463 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceCheckerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java
@@ -25,21 +25,23 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
-import static com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker.ServiceResponse;
+import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static com.yahoo.vespa.config.server.application.ConfigConvergenceChecker.ServiceResponse;
/**
* @author Ulf Lilleengen
*/
-public class ApplicationConvergenceCheckerTest {
+public class ConfigConvergenceCheckerTest {
private static final ObjectMapper mapper = new ObjectMapper();
private final TenantName tenant = TenantName.from("mytenant");
private final ApplicationId appId = ApplicationId.from(tenant, ApplicationName.from("myapp"), InstanceName.from("myinstance"));
private Application application;
+ private ConfigConvergenceChecker checker;
private Map<URI, Long> currentGeneration;
- private ApplicationConvergenceChecker checker;
@Rule
public TemporaryFolder folder = new TemporaryFolder();
@@ -54,7 +56,7 @@ public class ApplicationConvergenceCheckerTest {
Version.fromIntValues(0, 0, 0),
MetricUpdater.createTestUpdater(), appId);
currentGeneration = new HashMap<>();
- checker = new ApplicationConvergenceChecker(
+ checker = new ConfigConvergenceChecker(
(client, serviceUri) -> () -> asJson("{\"config\":{\"generation\":"
+ currentGeneration.getOrDefault(serviceUri, 3L)
+ "}}"));
@@ -62,37 +64,27 @@ public class ApplicationConvergenceCheckerTest {
@Test
public void service_convergence() throws Exception {
- ServiceResponse serviceResponse = checker.serviceConvergenceCheck(application,
- "localhost:1337",
- URI.create("http://foo:234/serviceconverge/localhost:1337"));
+ ServiceResponse serviceResponse = checker.checkService(application,
+ "localhost:1337",
+ URI.create("http://foo:234/serviceconverge/localhost:1337"));
assertEquals(200, serviceResponse.getStatus());
assertJsonEquals("{\n" +
" \"url\": \"http://foo:234/serviceconverge/localhost:1337\",\n" +
" \"host\": \"localhost:1337\",\n" +
" \"wantedGeneration\": 3,\n" +
- " \"debug\": {\n" +
- " \"host\": \"localhost:1337\",\n" +
- " \"wantedGeneration\": 3,\n" +
- " \"currentGeneration\": 3\n" +
- " },\n" +
" \"converged\": true,\n" +
" \"currentGeneration\": 3\n" +
"}",
SessionHandlerTest.getRenderedString(serviceResponse));
- ServiceResponse hostMissingResponse = checker.serviceConvergenceCheck(application,
- "notPresent:1337",
- URI.create("http://foo:234/serviceconverge/notPresent:1337"));
+ ServiceResponse hostMissingResponse = checker.checkService(application,
+ "notPresent:1337",
+ URI.create("http://foo:234/serviceconverge/notPresent:1337"));
assertEquals(410, hostMissingResponse.getStatus());
assertJsonEquals("{\n" +
" \"url\": \"http://foo:234/serviceconverge/notPresent:1337\",\n" +
" \"host\": \"notPresent:1337\",\n" +
" \"wantedGeneration\": 3,\n" +
- " \"debug\": {\n" +
- " \"host\": \"notPresent:1337\",\n" +
- " \"wantedGeneration\": 3,\n" +
- " \"problem\": \"Host:port (service) no longer part of application, refetch list of services.\"\n" +
- " },\n" +
" \"problem\": \"Host:port (service) no longer part of application, refetch list of services.\"\n" +
"}",
SessionHandlerTest.getRenderedString(hostMissingResponse));
@@ -100,8 +92,7 @@ public class ApplicationConvergenceCheckerTest {
@Test
public void service_list_convergence() throws Exception {
- HttpResponse serviceListResponse = checker.serviceListToCheckForConfigConvergence(application,
- URI.create("http://foo:234/serviceconverge"));
+ HttpResponse serviceListResponse = checker.servicesToCheck(application, URI.create("http://foo:234/serviceconverge"));
assertEquals(200, serviceListResponse.getStatus());
assertJsonEquals("{\n" +
" \"services\": [\n" +
@@ -115,10 +106,7 @@ public class ApplicationConvergenceCheckerTest {
" \"url\": \"http://foo:234/serviceconverge\",\n" +
" \"currentGeneration\": 3,\n" +
" \"wantedGeneration\": 3,\n" +
- " \"converged\": true,\n" +
- " \"debug\": {\n" +
- " \"wantedGeneration\": 3\n" +
- " }\n" +
+ " \"converged\": true\n" +
"}",
SessionHandlerTest.getRenderedString(serviceListResponse));
@@ -132,7 +120,7 @@ public class ApplicationConvergenceCheckerTest {
Version.fromIntValues(0, 0, 0),
MetricUpdater.createTestUpdater(), appId);
currentGeneration.put(URI.create("http://host2:1234"), 4L);
- serviceListResponse = checker.serviceListToCheckForConfigConvergence(application, URI.create("http://foo:234/serviceconverge"));
+ serviceListResponse = checker.servicesToCheck(application, URI.create("http://foo:234/serviceconverge"));
assertEquals(200, serviceListResponse.getStatus());
assertJsonEquals("{\n" +
" \"services\": [\n" +
@@ -152,10 +140,7 @@ public class ApplicationConvergenceCheckerTest {
" \"url\": \"http://foo:234/serviceconverge\",\n" +
" \"currentGeneration\": 3,\n" +
" \"wantedGeneration\": 4,\n" +
- " \"converged\": false,\n" +
- " \"debug\": {\n" +
- " \"wantedGeneration\": 4\n" +
- " }\n" +
+ " \"converged\": false\n" +
"}",
SessionHandlerTest.getRenderedString(serviceListResponse));
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
index a4bfb1de221..3fa1b3fdb5e 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java
@@ -5,9 +5,12 @@ import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.TenantName;
import com.yahoo.text.Utf8;
import com.yahoo.vespa.config.server.MockReloadHandler;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.tenant.TenantRepository;
+import com.yahoo.vespa.curator.Curator;
+import com.yahoo.vespa.curator.mock.MockCurator;
+import org.apache.curator.framework.CuratorFramework;
+import org.junit.Before;
import org.junit.Test;
import java.util.Arrays;
@@ -19,12 +22,20 @@ import static org.junit.Assert.*;
/**
* @author Ulf Lilleengen
- * @since 5.1
*/
-public class TenantApplicationsTest extends TestWithCurator {
+public class TenantApplicationsTest {
private static final TenantName tenantName = TenantName.from("tenant");
+ private Curator curator;
+ private CuratorFramework curatorFramework;
+
+ @Before
+ public void setup() {
+ curator = new MockCurator();
+ curatorFramework = curator.framework();
+ }
+
@Test
public void require_that_applications_are_read_from_zookeeper() throws Exception {
writeApplicationData(createApplicationId("foo"), 3L);
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java
index 945c7d60750..c4a0fd9f3f0 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java
@@ -8,7 +8,6 @@ import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.config.model.application.provider.*;
import com.yahoo.config.provision.*;
import com.yahoo.path.Path;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.zookeeper.ZKApplicationPackage;
import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.config.server.zookeeper.ConfigCurator;
@@ -25,13 +24,12 @@ import java.util.*;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.*;
-
/**
* Unit tests for ZooKeeperClient.
*
* @author hmusum
*/
-public class ZooKeeperClientTest extends TestWithCurator {
+public class ZooKeeperClientTest {
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@@ -41,7 +39,7 @@ public class ZooKeeperClientTest extends TestWithCurator {
@Before
public void setupZK() throws IOException {
- this.zk = ConfigCurator.create(curator);
+ zk = ConfigCurator.create(new MockCurator());
ZooKeeperClient zkc = new ZooKeeperClient(zk, new BaseDeployLogger(), true, Path.fromString(appPath));
ApplicationPackage app = FilesApplicationPackage.fromFileWithDeployData(new File("src/test/apps/zkfeed"),
new DeployData("foo",
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
index 2d39efb9013..d8c5e33ca65 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java
@@ -12,7 +12,7 @@ import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.jdisc.Response;
import com.yahoo.vespa.config.server.ApplicationRepository;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker;
+import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker;
import com.yahoo.vespa.config.server.application.HttpProxy;
import com.yahoo.vespa.config.server.http.HandlerTest;
import com.yahoo.vespa.config.server.http.HttpErrorResponse;
@@ -69,7 +69,7 @@ public class ApplicationHandlerTest {
tenantRepository.addTenant(TenantBuilder.create(componentRegistry, foobar));
provisioner = new SessionHandlerTest.MockProvisioner();
applicationRepository = new ApplicationRepository(tenantRepository,
- new ApplicationConvergenceChecker(stateApiFactory),
+ new ConfigConvergenceChecker(stateApiFactory),
provisioner, Clock.systemUTC());
listApplicationsHandler = new ListApplicationsHandler(ListApplicationsHandler.testOnlyContext(),
tenantRepository,
@@ -163,7 +163,7 @@ public class ApplicationHandlerTest {
HttpProxy mockHttpProxy = mock(HttpProxy.class);
ApplicationRepository applicationRepository = new ApplicationRepository(tenantRepository,
HostProvisionerProvider.withProvisioner(provisioner),
- new ApplicationConvergenceChecker(stateApiFactory),
+ new ConfigConvergenceChecker(stateApiFactory),
mockHttpProxy,
new ConfigserverConfig(new ConfigserverConfig.Builder()));
ApplicationHandler mockHandler = createApplicationHandler(applicationRepository);
@@ -276,10 +276,10 @@ public class ApplicationHandlerTest {
return createApplicationHandler().handle(HttpRequest.createTestRequest(restartUrl, com.yahoo.jdisc.http.HttpRequest.Method.GET));
}
- private static class MockStateApiFactory implements ApplicationConvergenceChecker.StateApiFactory {
+ private static class MockStateApiFactory implements ConfigConvergenceChecker.StateApiFactory {
boolean createdApi = false;
@Override
- public ApplicationConvergenceChecker.StateApi createStateApi(Client client, URI serviceUri) {
+ public ConfigConvergenceChecker.StateApi createStateApi(Client client, URI serviceUri) {
createdApi = true;
return () -> {
try {
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java
index 987dd8a6c4d..829dfb978b2 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java
@@ -6,13 +6,13 @@ import com.yahoo.test.ManualClock;
import com.yahoo.config.provision.TenantName;
import com.yahoo.vespa.config.server.GlobalComponentRegistry;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.application.MemoryTenantApplications;
import com.yahoo.vespa.config.server.deploy.TenantFileSystemDirs;
import com.yahoo.io.IOUtils;
import com.yahoo.vespa.config.server.host.HostRegistry;
import com.yahoo.vespa.config.server.http.SessionHandlerTest;
+import com.yahoo.vespa.curator.mock.MockCurator;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -29,7 +29,7 @@ import static org.junit.Assert.fail;
/**
* @author Ulf Lilleengen
*/
-public class LocalSessionRepoTest extends TestWithCurator {
+public class LocalSessionRepoTest {
private File testApp = new File("src/test/apps/app");
private LocalSessionRepo repo;
@@ -45,7 +45,7 @@ public class LocalSessionRepoTest extends TestWithCurator {
}
private void setupSessions(TenantName tenantName, boolean createInitialSessions) throws Exception {
- GlobalComponentRegistry globalComponentRegistry = new TestComponentRegistry.Builder().curator(curator).build();
+ GlobalComponentRegistry globalComponentRegistry = new TestComponentRegistry.Builder().curator(new MockCurator()).build();
TenantFileSystemDirs tenantFileSystemDirs = new TenantFileSystemDirs(temporaryFolder.newFolder(), tenantName);
if (createInitialSessions) {
IOUtils.copyDirectory(testApp, new File(tenantFileSystemDirs.sessionsPath(), "1"));
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java
index 88997d29572..4bbfea48254 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java
@@ -13,12 +13,12 @@ import com.yahoo.text.Utf8;
import com.yahoo.transaction.Transaction;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.application.TenantApplications;
import com.yahoo.vespa.config.server.tenant.Tenant;
import com.yahoo.vespa.config.server.tenant.TenantBuilder;
import com.yahoo.vespa.config.server.tenant.TenantRepository;
import com.yahoo.vespa.curator.Curator;
+import com.yahoo.vespa.curator.mock.MockCurator;
import org.junit.Before;
import org.junit.Test;
@@ -32,16 +32,17 @@ import java.util.function.LongPredicate;
/**
* @author Ulf Lilleengen
- * @since 5.1
*/
-public class RemoteSessionRepoTest extends TestWithCurator {
+public class RemoteSessionRepoTest {
private static final TenantName tenantName = TenantName.defaultName();
private RemoteSessionRepo remoteSessionRepo;
+ private Curator curator;
@Before
- public void setupFacade() throws Exception {
+ public void setupFacade() {
+ curator = new MockCurator();
Tenant tenant = TenantBuilder.create(new TestComponentRegistry.Builder()
.curator(curator)
.build(),
@@ -75,7 +76,7 @@ public class RemoteSessionRepoTest extends TestWithCurator {
}
@Test
- public void testCreateSession() throws Exception {
+ public void testCreateSession() {
createSession(3l, true);
assertSessionExists(3l);
}
@@ -99,7 +100,7 @@ public class RemoteSessionRepoTest extends TestWithCurator {
// repo even if it had bad data (by making getSessionIdForApplication() in FailingTenantApplications
// throw an exception).
@Test
- public void testBadApplicationRepoOnActivate() throws Exception {
+ public void testBadApplicationRepoOnActivate() {
long sessionId = 3L;
TenantApplications applicationRepo = new FailingTenantApplications();
TenantName mytenant = TenantName.from("mytenant");
@@ -116,7 +117,7 @@ public class RemoteSessionRepoTest extends TestWithCurator {
private void assertStatusChange(long sessionId, Session.Status status) throws Exception {
Path statePath = TenantRepository.getSessionsPath(tenantName).append("" + sessionId).append(ConfigCurator.SESSIONSTATE_ZK_SUBPATH);
curator.create(statePath);
- curatorFramework.setData().forPath(statePath.getAbsolute(), Utf8.toBytes(status.toString()));
+ curator.framework().setData().forPath(statePath.getAbsolute(), Utf8.toBytes(status.toString()));
System.out.println("Setting status " + status + " for " + sessionId);
assertSessionStatus(sessionId, status);
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java
index 39fe27e5adb..b57d2d1a1a1 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java
@@ -22,9 +22,7 @@ import com.yahoo.vespa.model.VespaModelFactory;
import org.junit.Before;
import org.junit.Test;
-import org.xml.sax.SAXException;
-import java.io.IOException;
import java.time.Clock;
import java.time.Instant;
import java.time.LocalDate;
@@ -42,8 +40,7 @@ import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/**
- * @author lulf
- * @since 5.1
+ * @author Ulf Lilleengen
*/
public class RemoteSessionTest {
@@ -52,7 +49,7 @@ public class RemoteSessionTest {
private Curator curator;
@Before
- public void setupTest() throws Exception {
+ public void setupTest() {
curator = new MockCurator();
}
@@ -66,7 +63,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_applications_are_loaded() throws IOException, SAXException {
+ public void require_that_applications_are_loaded() {
RemoteSession session = createSession(3, Arrays.asList(new MockModelFactory(), new VespaModelFactory(new NullConfigModelRegistry())), Clock.systemUTC());
session.loadPrepared();
ApplicationSet applicationSet = session.ensureApplicationLoaded();
@@ -84,7 +81,7 @@ public class RemoteSessionTest {
}
@Test(expected = IllegalArgumentException.class)
- public void require_that_new_invalid_application_throws_exception() throws IOException, SAXException {
+ public void require_that_new_invalid_application_throws_exception() {
MockModelFactory failingFactory = new MockModelFactory();
failingFactory.vespaVersion = Version.fromIntValues(1, 2, 0);
failingFactory.throwOnLoad = true;
@@ -98,7 +95,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_application_incompatible_with_latestmajor_is_loaded_on_earlier_major() throws IOException, SAXException {
+ public void require_that_application_incompatible_with_latestmajor_is_loaded_on_earlier_major() {
MockModelFactory okFactory1 = new MockModelFactory();
okFactory1.vespaVersion = Version.fromIntValues(1, 1, 0);
okFactory1.throwOnLoad = false;
@@ -116,7 +113,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_old_invalid_application_does_not_throw_exception_if_skipped() throws IOException, SAXException {
+ public void require_that_old_invalid_application_does_not_throw_exception_if_skipped() {
MockModelFactory failingFactory = new MockModelFactory();
failingFactory.vespaVersion = Version.fromIntValues(1, 1, 0);
failingFactory.throwOnLoad = true;
@@ -131,7 +128,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_across_major_versions() throws IOException, SAXException {
+ public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_across_major_versions() {
MockModelFactory failingFactory = new MockModelFactory();
failingFactory.vespaVersion = Version.fromIntValues(1, 0, 0);
failingFactory.throwOnLoad = true;
@@ -146,7 +143,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_when_new_major_is_incompatible() throws IOException, SAXException {
+ public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_when_new_major_is_incompatible() {
MockModelFactory failingFactory = new MockModelFactory();
failingFactory.vespaVersion = Version.fromIntValues(1, 0, 0);
failingFactory.throwOnLoad = true;
@@ -166,7 +163,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_an_application_package_can_limit_to_one_major_version() throws IOException, SAXException {
+ public void require_that_an_application_package_can_limit_to_one_major_version() {
ApplicationPackage application =
new MockApplicationPackage.Builder().withServices("<services major-version='2' version=\"1.0\"></services>").build();
@@ -186,7 +183,7 @@ public class RemoteSessionTest {
}
@Test
- public void require_that_session_status_is_updated() throws IOException, SAXException {
+ public void require_that_session_status_is_updated() {
SessionZooKeeperClient zkc = new MockSessionZKClient(curator, tenantName, 3);
RemoteSession session = createSession(3, zkc, Clock.systemUTC());
assertThat(session.getStatus(), is(Session.Status.NEW));
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java
index 0ca487cfb67..6c8be2ac2f3 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java
@@ -15,7 +15,8 @@ import com.yahoo.vespa.config.server.http.CompressedApplicationInputStream;
import com.yahoo.vespa.config.server.http.CompressedApplicationInputStreamTest;
import com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler;
-import com.yahoo.vespa.config.server.tenant.TestWithTenant;
+import com.yahoo.vespa.config.server.tenant.TenantRepository;
+import com.yahoo.vespa.curator.mock.MockCurator;
import org.json.JSONException;
import org.json.JSONObject;
import org.junit.Before;
@@ -33,12 +34,13 @@ import static org.junit.Assert.assertTrue;
/**
* @author Ulf Lilleengen
*/
-public class SessionFactoryTest extends TestWithTenant {
+public class SessionFactoryTest {
private SessionFactory factory;
@Before
public void setup_test() {
- factory = tenant.getSessionFactory();
+ TenantRepository tenantRepository = new TenantRepository(new TestComponentRegistry.Builder().curator(new MockCurator()).build());
+ factory = tenantRepository.defaultTenant().getSessionFactory();
}
@Test
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java
index 4fbd7fe7232..92fb67fdd54 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java
@@ -32,11 +32,11 @@ import com.yahoo.vespa.config.server.provision.HostProvisionerProvider;
import com.yahoo.vespa.config.server.tenant.Rotations;
import com.yahoo.vespa.config.server.zookeeper.ConfigCurator;
+import com.yahoo.vespa.curator.mock.MockCurator;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
-import org.xml.sax.SAXException;
import java.io.File;
import java.io.IOException;
@@ -52,13 +52,15 @@ import static org.junit.Assert.*;
* @author lulf
* @since 5.1
*/
-public class SessionPreparerTest extends TestWithCurator {
+public class SessionPreparerTest {
private static final Path tenantPath = Path.createRoot();
private static final Path sessionsPath = tenantPath.append("sessions").append("testapp");
private static final File testApp = new File("src/test/apps/app");
private static final File invalidTestApp = new File("src/test/apps/illegalApp");
+ private MockCurator curator;
+ private ConfigCurator configCurator;
private SessionPreparer preparer;
private TestComponentRegistry componentRegistry;
private MockFileDistributionFactory fileDistributionFactory;
@@ -69,6 +71,8 @@ public class SessionPreparerTest extends TestWithCurator {
@Before
public void setUp() {
+ curator = new MockCurator();
+ configCurator = ConfigCurator.create(curator);
componentRegistry = new TestComponentRegistry.Builder().curator(curator).build();
fileDistributionFactory = (MockFileDistributionFactory)componentRegistry.getFileDistributionFactory();
preparer = createPreparer();
@@ -99,13 +103,13 @@ public class SessionPreparerTest extends TestWithCurator {
}
@Test(expected = InvalidApplicationException.class)
- public void require_that_application_validation_exception_is_not_caught() throws IOException, SAXException {
+ public void require_that_application_validation_exception_is_not_caught() throws IOException {
FilesApplicationPackage app = getApplicationPackage(invalidTestApp);
preparer.prepare(getContext(app), getLogger(), new PrepareParams.Builder().build(), Optional.empty(), tenantPath, Instant.now());
}
@Test
- public void require_that_application_validation_exception_is_ignored_if_forced() throws IOException, SAXException {
+ public void require_that_application_validation_exception_is_ignored_if_forced() throws IOException {
FilesApplicationPackage app = getApplicationPackage(invalidTestApp);
preparer.prepare(getContext(app), getLogger(),
new PrepareParams.Builder().ignoreValidationErrors(true).timeoutBudget(TimeoutBudgetTest.day()).build(),
@@ -250,18 +254,18 @@ public class SessionPreparerTest extends TestWithCurator {
return FilesApplicationPackage.fromFile(appDir);
}
- DeployHandlerLogger getLogger() {
+ private DeployHandlerLogger getLogger() {
return getLogger(false);
}
- DeployHandlerLogger getLogger(boolean verbose) {
+ private DeployHandlerLogger getLogger(boolean verbose) {
return new DeployHandlerLogger(new Slime().get(), verbose,
new ApplicationId.Builder().tenant("testtenant").applicationName("testapp").build());
}
private static class FailingModelFactory extends TestModelFactory {
private final RuntimeException exception;
- public FailingModelFactory(Version vespaVersion, RuntimeException exception) {
+ FailingModelFactory(Version vespaVersion, RuntimeException exception) {
super(vespaVersion);
this.exception = exception;
}
@@ -279,7 +283,7 @@ public class SessionPreparerTest extends TestWithCurator {
private static class ConfigChangeActionsModelFactory extends TestModelFactory {
private final ConfigChangeAction action;
- public ConfigChangeActionsModelFactory(Version vespaVersion, ConfigChangeAction action) {
+ ConfigChangeActionsModelFactory(Version vespaVersion, ConfigChangeAction action) {
super(vespaVersion);
this.action = action;
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java
index 98ba3d4e178..522a21a47b3 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java
@@ -4,8 +4,10 @@ package com.yahoo.vespa.config.server.session;
import com.yahoo.path.Path;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.text.Utf8;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.zookeeper.ConfigCurator;
+import com.yahoo.vespa.curator.Curator;
+import com.yahoo.vespa.curator.mock.MockCurator;
+import org.junit.Before;
import org.junit.Test;
import java.util.concurrent.TimeUnit;
@@ -15,10 +17,18 @@ import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/**
- * @author lulf
- * @since 5.1
+ * @author Ulf Lilleengen
*/
-public class SessionZooKeeperClientTest extends TestWithCurator {
+public class SessionZooKeeperClientTest {
+
+ private Curator curator;
+ private ConfigCurator configCurator;
+
+ @Before
+ public void setup() {
+ curator = new MockCurator();
+ configCurator = ConfigCurator.create(curator);
+ }
@Test
public void require_that_status_can_be_updated() {
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
index f47ed69ad14..046369edce0 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java
@@ -8,9 +8,10 @@ import com.yahoo.config.provision.Version;
import com.yahoo.vespa.config.server.application.ApplicationSet;
import com.yahoo.vespa.config.server.ServerCache;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.application.Application;
import com.yahoo.vespa.config.server.monitoring.MetricUpdater;
+import com.yahoo.vespa.curator.Curator;
+import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.model.VespaModel;
import org.junit.After;
import org.junit.Before;
@@ -30,17 +31,20 @@ import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-public class TenantRepositoryTest extends TestWithCurator {
+public class TenantRepositoryTest {
+ private static final TenantName tenant1 = TenantName.from("tenant1");
+ private static final TenantName tenant2 = TenantName.from("tenant2");
+ private static final TenantName tenant3 = TenantName.from("tenant3");
+
private TenantRepository tenantRepository;
private TestComponentRegistry globalComponentRegistry;
private TenantRequestHandlerTest.MockReloadListener listener;
private MockTenantListener tenantListener;
- private final TenantName tenant1 = TenantName.from("tenant1");
- private final TenantName tenant2 = TenantName.from("tenant2");
- private final TenantName tenant3 = TenantName.from("tenant3");
+ private Curator curator;
@Before
public void setupSessions() {
+ curator = new MockCurator();
globalComponentRegistry = new TestComponentRegistry.Builder().curator(curator).build();
listener = (TenantRequestHandlerTest.MockReloadListener)globalComponentRegistry.getReloadListener();
tenantListener = (MockTenantListener)globalComponentRegistry.getTenantListener();
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java
index cecbab2d9ec..d517eb195a7 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java
@@ -24,8 +24,6 @@ import com.yahoo.vespa.config.server.host.HostRegistries;
import com.yahoo.vespa.config.server.ReloadListener;
import com.yahoo.vespa.config.server.ServerCache;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestConfigDefinitionRepo;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.rpc.UncompressedConfigResponseFactory;
import com.yahoo.vespa.config.server.application.Application;
import com.yahoo.config.provision.ApplicationId;
@@ -37,6 +35,8 @@ import com.yahoo.vespa.config.server.monitoring.MetricUpdater;
import com.yahoo.vespa.config.server.monitoring.Metrics;
import com.yahoo.vespa.config.server.session.RemoteSession;
import com.yahoo.vespa.config.server.session.SessionZooKeeperClient;
+import com.yahoo.vespa.curator.Curator;
+import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.VespaModelFactory;
@@ -58,7 +58,7 @@ import static org.junit.Assert.*;
/**
* @author Ulf Lilleengen
*/
-public class TenantRequestHandlerTest extends TestWithCurator {
+public class TenantRequestHandlerTest {
private static final Version vespaVersion = new VespaModelFactory(new NullConfigModelRegistry()).getVersion();
private TenantRequestHandler server;
@@ -67,6 +67,7 @@ public class TenantRequestHandlerTest extends TestWithCurator {
private File app2 = new File("src/test/apps/cs2");
private TenantName tenant = TenantName.from("mytenant");
private TestComponentRegistry componentRegistry;
+ private Curator curator;
@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();
@@ -77,6 +78,8 @@ public class TenantRequestHandlerTest extends TestWithCurator {
@Before
public void setUp() throws IOException {
+ curator = new MockCurator();
+
feedApp(app1, 1, defaultApp(), false);
Metrics sh = Metrics.createTestMetrics();
List<ReloadListener> listeners = new ArrayList<>();
@@ -86,10 +89,7 @@ public class TenantRequestHandlerTest extends TestWithCurator {
}
private void feedApp(File appDir, long sessionId, ApplicationId appId, boolean internalRedeploy) throws IOException {
- SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator,
- TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId)),
- new TestConfigDefinitionRepo(),
- "", Optional.empty());
+ SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId)));
zkc.writeApplicationId(appId);
File app = tempFolder.newFolder();
IOUtils.copyDirectory(appDir, app);
@@ -107,17 +107,14 @@ public class TenantRequestHandlerTest extends TestWithCurator {
AllocatedHosts.withHosts(Collections.emptySet()));
}
- private ApplicationSet reloadConfig(long id, Clock clock) {
- return reloadConfig(id, "default", clock);
+ private ApplicationSet reloadConfig(long sessionId, Clock clock) {
+ return reloadConfig(sessionId, "default", clock);
}
- private ApplicationSet reloadConfig(long id, String application, Clock clock) {
- SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator,
- TenantRepository.getSessionsPath(tenant).append(String.valueOf(id)),
- new TestConfigDefinitionRepo(),
- "", Optional.empty());
+ private ApplicationSet reloadConfig(long sessionId, String application, Clock clock) {
+ SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId)));
zkc.writeApplicationId(new ApplicationId.Builder().tenant(tenant).applicationName(application).build());
- RemoteSession session = new RemoteSession(tenant, id, componentRegistry, zkc, clock);
+ RemoteSession session = new RemoteSession(tenant, sessionId, componentRegistry, zkc, clock);
return session.ensureApplicationLoaded();
}
@@ -207,10 +204,7 @@ public class TenantRequestHandlerTest extends TestWithCurator {
@Test
public void testResolveForAppId() {
long id = 1L;
- SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator,
- TenantRepository.getSessionsPath(tenant).append(String.valueOf(id)),
- new TestConfigDefinitionRepo(),
- "", Optional.empty());
+ SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(id)));
ApplicationId appId = new ApplicationId.Builder()
.tenant(tenant)
.applicationName("myapp").instanceName("myinst").build();
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
index 1975899355c..1b3afeb353b 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java
@@ -4,7 +4,6 @@ package com.yahoo.vespa.config.server.tenant;
import com.google.common.testing.EqualsTester;
import com.yahoo.config.provision.TenantName;
import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestWithCurator;
import com.yahoo.vespa.config.server.application.MemoryTenantApplications;
import org.junit.Before;
import org.junit.Test;
@@ -14,10 +13,9 @@ import static org.hamcrest.Matchers.is;
import static org.junit.Assert.*;
/**
- * @author lulf
- * @since 5.3
+ * @author Ulf Lilleengen
*/
-public class TenantTest extends TestWithCurator {
+public class TenantTest {
private final TestComponentRegistry componentRegistry = new TestComponentRegistry.Builder().build();
private Tenant t1;
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java
deleted file mode 100644
index 67fb320d821..00000000000
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.config.server.tenant;
-
-import com.yahoo.vespa.config.server.TestComponentRegistry;
-import com.yahoo.vespa.config.server.TestWithCurator;
-import org.junit.Before;
-
-/**
- * Utility for a test using a single default tenant.
- *
- * @author lulf
- * @since 5.35
- */
-public class TestWithTenant extends TestWithCurator {
-
- protected TenantRepository tenantRepository;
- protected Tenant tenant;
-
- @Before
- public void setupTenant() throws Exception {
- tenantRepository = new TenantRepository(new TestComponentRegistry.Builder().curator(curator).build());
- tenant = tenantRepository.defaultTenant();
- }
-
-}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java
index b444e09f558..888cbb7a68b 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java
@@ -1,11 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.config.server.zookeeper;
-import com.yahoo.vespa.config.server.TestWithCurator;
-import org.junit.Before;
-import org.junit.Rule;
+import com.yahoo.vespa.curator.mock.MockCurator;
import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
@@ -13,20 +10,15 @@ import static org.junit.Assert.assertThat;
/**
* @author Ulf Lilleengen
*/
-public class InitializedCounterTest extends TestWithCurator {
+public class InitializedCounterTest {
- @Rule
- public TemporaryFolder folder = new TemporaryFolder();
-
- @Before
- public void setupZK() {
+ @Test
+ public void requireThatCounterIsInitializedFromNumberOfSessions() {
+ ConfigCurator configCurator = ConfigCurator.create(new MockCurator());
configCurator.createNode("/sessions");
configCurator.createNode("/sessions/1");
configCurator.createNode("/sessions/2");
- }
- @Test
- public void requireThatCounterIsInitializedFromNumberOfSessions() {
InitializedCounter counter = new InitializedCounter(configCurator, "/counter", "/sessions");
assertThat(counter.counter.get(), is(2l));
}
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java
index f0c74d19af9..06908dbab51 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java
@@ -21,14 +21,15 @@ import com.yahoo.config.provision.Version;
import com.yahoo.config.provisioning.FlavorsConfig;
import com.yahoo.path.Path;
import com.yahoo.text.Utf8;
-import com.yahoo.vespa.config.server.TestWithCurator;
+import com.yahoo.vespa.curator.mock.MockCurator;
+import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import com.yahoo.io.IOUtils;
-public class ZKApplicationPackageTest extends TestWithCurator {
+public class ZKApplicationPackageTest {
private static final String APP = "src/test/apps/zkapp";
private static final String TEST_FLAVOR_NAME = "test-flavor";
@@ -37,9 +38,16 @@ public class ZKApplicationPackageTest extends TestWithCurator {
Collections.singleton(new HostSpec("foo.yahoo.com", Collections.emptyList(), TEST_FLAVOR, Optional.empty(),
Optional.of(com.yahoo.component.Version.fromString("6.0.1")))));
+ private ConfigCurator configCurator;
+
@Rule
public TemporaryFolder tmpDir = new TemporaryFolder();
+ @Before
+ public void setup() {
+ configCurator = ConfigCurator.create(new MockCurator());
+ }
+
@Test
public void testBasicZKFeed() throws IOException {
feed(configCurator, new File(APP));
diff --git a/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java b/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java
index bcd6e930ee3..d7457140dae 100644
--- a/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java
+++ b/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java
@@ -6,6 +6,7 @@ import java.util.Map;
import com.google.inject.Inject;
import com.yahoo.container.QrSearchersConfig;
+import com.yahoo.container.core.VipStatusConfig;
/**
* API for programmatically removing the container from VIP rotation.
@@ -15,15 +16,22 @@ import com.yahoo.container.QrSearchersConfig;
public class VipStatus {
private final Map<Object, Boolean> clusters = new IdentityHashMap<>();
+ private final VipStatusConfig vipStatusConfig;
public VipStatus() {
- this(null);
+ this(null, new VipStatusConfig(new VipStatusConfig.Builder()));
}
- @Inject
public VipStatus(QrSearchersConfig dispatchers) {
+ this(dispatchers, new VipStatusConfig(new VipStatusConfig.Builder()));
+ }
+
+ // TODO: Why use QrSearchersConfig here? Remove and inject ComponentRegistry<ClusterSearcher> instead?
+ @Inject
+ public VipStatus(QrSearchersConfig dispatchers, VipStatusConfig vipStatusConfig) {
// the config is not used for anything, it's just a dummy to create a
// dependency link to which dispatchers are used
+ this.vipStatusConfig = vipStatusConfig;
}
/**
@@ -55,14 +63,14 @@ public class VipStatus {
/**
* Tell whether the container is connected to any active services at all.
*
- * @return true if at least one service or cluster is up, or if no services
+ * @return true if at least one service or cluster is up, or value is taken from config if no services
* are registered (yet)
*/
public boolean isInRotation() {
synchronized (clusters) {
- // if no stored state, try serving
+ // if no stored state, use config to decide whether to serve or not
if (clusters.size() == 0) {
- return true;
+ return vipStatusConfig.initiallyInRotation();
}
for (Boolean inRotation : clusters.values()) {
if (inRotation) {
diff --git a/container-core/src/main/resources/configdefinitions/vip-status.def b/container-core/src/main/resources/configdefinitions/vip-status.def
index 44da7292f05..1e364419ab8 100644
--- a/container-core/src/main/resources/configdefinitions/vip-status.def
+++ b/container-core/src/main/resources/configdefinitions/vip-status.def
@@ -6,9 +6,12 @@ namespace=container.core
## rotation, ignoring any status file.
noSearchBackendsImpliesOutOfService bool default=true
-## Whether to return hard coded reply or serve "status.html" from disk
+## Whether to return hard-coded reply or serve "status.html" from disk
accessdisk bool default=false
## The file to serve as the status file.
-## If the paht is relative vespa home is prepended
+## If the path is relative vespa home is prepended
statusfile string default="share/qrsdocs/status.html"
+
+## The initial rotation state when no information is known about backend clusters
+initiallyInRotation bool default=true
diff --git a/container-dependency-versions/pom.xml b/container-dependency-versions/pom.xml
index b4af6800768..f546a4e36d2 100644
--- a/container-dependency-versions/pom.xml
+++ b/container-dependency-versions/pom.xml
@@ -459,7 +459,7 @@
<properties>
<bouncycastle.version>1.58</bouncycastle.version>
- <felix.version>5.0.1</felix.version>
+ <felix.version>5.4.0</felix.version>
<findbugs.version>1.3.9</findbugs.version>
<guava.version>18.0</guava.version>
<guice.version>3.0</guice.version>
diff --git a/container-dev/pom.xml b/container-dev/pom.xml
index 53153a05c4a..ff67e8db9fe 100644
--- a/container-dev/pom.xml
+++ b/container-dev/pom.xml
@@ -123,10 +123,6 @@
<version>${project.version}</version>
<exclusions>
<exclusion>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm</artifactId>
- </exclusion>
- <exclusion>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</exclusion>
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java
new file mode 100644
index 00000000000..04fd8572ad4
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java
@@ -0,0 +1,94 @@
+package com.yahoo.container.jdisc.metric;
+
+import com.yahoo.jdisc.Metric;
+
+import java.lang.management.GarbageCollectorMXBean;
+import java.lang.management.ManagementFactory;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.Map;
+
+/**
+ * @author ollivir
+ */
+public class GarbageCollectionMetrics {
+ private static final String GC_COUNT = "jdisc.gc.count";
+ private static final String GC_TIME = "jdisc.gc.ms";
+ private static final String DIMENSION_KEY = "gcName";
+
+ public static final Duration REPORTING_INTERVAL = Duration.ofSeconds(62);
+
+ static class GcStats {
+ private final Instant when;
+ private final long count;
+ private final Duration totalRuntime;
+
+ private GcStats(Instant when, long count, Duration totalRuntime) {
+ this.when = when;
+ this.count = count;
+ this.totalRuntime = totalRuntime;
+ }
+ }
+
+ private Map<String, LinkedList<GcStats>> gcStatistics;
+
+ private final Clock clock;
+
+ public GarbageCollectionMetrics(Clock clock) {
+ this.clock = clock;
+ this.gcStatistics = new HashMap<>();
+ collectGcStatistics(clock.instant());
+ }
+
+ private void collectGcStatistics(Instant now) {
+ for (GarbageCollectorMXBean gcBean : ManagementFactory.getGarbageCollectorMXBeans()) {
+ String gcName = gcBean.getName().replace(" ", "");
+ GcStats stats = new GcStats(now, gcBean.getCollectionCount(), Duration.ofMillis(gcBean.getCollectionTime()));
+
+ LinkedList<GcStats> window = gcStatistics.computeIfAbsent(gcName, anyName -> new LinkedList<>());
+ window.addLast(stats);
+ }
+ }
+
+ private void cleanStatistics(Instant now) {
+ Instant oldestToKeep = now.minus(REPORTING_INTERVAL);
+
+ for(Iterator<Map.Entry<String, LinkedList<GcStats>>> it = gcStatistics.entrySet().iterator(); it.hasNext(); ) {
+ Map.Entry<String, LinkedList<GcStats>> entry = it.next();
+ LinkedList<GcStats> history = entry.getValue();
+ while(history.isEmpty() == false && oldestToKeep.isAfter(history.getFirst().when)) {
+ history.removeFirst();
+ }
+ if(history.isEmpty()) {
+ it.remove();
+ }
+ }
+ }
+
+ public void emitMetrics(Metric metric) {
+ Instant now = clock.instant();
+
+ collectGcStatistics(now);
+ cleanStatistics(now);
+
+ for (Map.Entry<String, LinkedList<GcStats>> item : gcStatistics.entrySet()) {
+ GcStats reference = item.getValue().getFirst();
+ GcStats latest = item.getValue().getLast();
+ Map<String, String> contextData = new HashMap<>();
+ contextData.put(DIMENSION_KEY, item.getKey());
+ Metric.Context gcContext = metric.createContext(contextData);
+
+ metric.set(GC_COUNT, latest.count - reference.count, gcContext);
+ metric.set(GC_TIME, latest.totalRuntime.minus(reference.totalRuntime).toMillis(), gcContext);
+ }
+ }
+
+ // partial exposure for testing
+ Map<String, LinkedList<GcStats>> getGcStatistics() {
+ return gcStatistics;
+ }
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
index 22b049c9ab7..c2ef789e8fc 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
@@ -10,6 +10,7 @@ import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
+import java.time.Clock;
import java.time.Duration;
import java.util.Timer;
import java.util.TimerTask;
@@ -89,10 +90,12 @@ public class MetricUpdater extends AbstractComponent {
private final Runtime runtime = Runtime.getRuntime();
private final Metric metric;
private final ContainerWatchdogMetrics containerWatchdogMetrics;
+ private final GarbageCollectionMetrics garbageCollectionMetrics;
public UpdaterTask(Metric metric, ContainerWatchdogMetrics containerWatchdogMetrics) {
this.metric = metric;
this.containerWatchdogMetrics = containerWatchdogMetrics;
+ this.garbageCollectionMetrics = new GarbageCollectionMetrics(Clock.systemUTC());
}
@SuppressWarnings("deprecation")
@@ -109,9 +112,10 @@ public class MetricUpdater extends AbstractComponent {
metric.set(TOTAL_MEMORY_BYTES, totalMemory, null);
metric.set(MEMORY_MAPPINGS_COUNT, count_mappings(), null);
metric.set(OPEN_FILE_DESCRIPTORS, count_open_files(), null);
+
containerWatchdogMetrics.emitMetrics(metric);
+ garbageCollectionMetrics.emitMetrics(metric);
}
-
}
private static class TimerScheduler implements Scheduler {
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java
new file mode 100644
index 00000000000..61d8763b852
--- /dev/null
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java
@@ -0,0 +1,57 @@
+package com.yahoo.container.jdisc.metric;
+
+import com.yahoo.jdisc.Metric;
+import com.yahoo.test.ManualClock;
+import org.junit.Test;
+
+import java.lang.management.ManagementFactory;
+import java.time.Duration;
+import java.util.LinkedList;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author ollivir
+ */
+public class GarbageCollectionMetricsTest {
+ @Test
+ public void gc_metrics_are_collected_in_a_sliding_window() {
+ ManualClock clock = new ManualClock();
+ Metric metric = mock(Metric.class);
+ int garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans().size();
+
+ Duration interval = GarbageCollectionMetrics.REPORTING_INTERVAL;
+ GarbageCollectionMetrics garbageCollectionMetrics = new GarbageCollectionMetrics(clock);
+ assertThat(garbageCollectionMetrics.getGcStatistics().keySet().size(), is(garbageCollectors));
+
+ clock.advance(interval.minus(Duration.ofMillis(10)));
+ garbageCollectionMetrics.emitMetrics(metric);
+ assertWindowLengths(garbageCollectionMetrics, 2);
+
+ clock.advance(Duration.ofMillis(10));
+ garbageCollectionMetrics.emitMetrics(metric);
+ assertWindowLengths(garbageCollectionMetrics, 3);
+
+ clock.advance(Duration.ofMillis(10));
+ garbageCollectionMetrics.emitMetrics(metric);
+ assertWindowLengths(garbageCollectionMetrics, 3);
+
+ clock.advance(interval);
+ garbageCollectionMetrics.emitMetrics(metric);
+ assertWindowLengths(garbageCollectionMetrics, 2);
+
+ verify(metric, times(garbageCollectors * 4 * 2)).set(anyString(), any(), any());
+ }
+
+ private static void assertWindowLengths(GarbageCollectionMetrics gcm, int count) {
+ for(LinkedList<GarbageCollectionMetrics.GcStats> window: gcm.getGcStatistics().values()) {
+ assertThat(window.size(), is(count));
+ }
+ }
+}
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java
index f10af7593a4..e9e04eab3b4 100644
--- a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java
@@ -5,6 +5,7 @@ import com.yahoo.jdisc.Metric;
import com.yahoo.jdisc.statistics.ContainerWatchdogMetrics;
import org.junit.Test;
+import java.lang.management.ManagementFactory;
import java.time.Duration;
import static org.mockito.Matchers.any;
@@ -20,11 +21,13 @@ public class MetricUpdaterTest {
@Test
public void metrics_are_updated_in_scheduler_cycle() throws InterruptedException {
+ int gcCount = ManagementFactory.getGarbageCollectorMXBeans().size();
+
Metric metric = mock(Metric.class);
ContainerWatchdogMetrics containerWatchdogMetrics = mock(ContainerWatchdogMetrics.class);
new MetricUpdater(new MockScheduler(), metric, containerWatchdogMetrics);
verify(containerWatchdogMetrics, times(1)).emitMetrics(any());
- verify(metric, times(8)).set(anyString(), any(), any());
+ verify(metric, times(8 + 2 * gcCount)).set(anyString(), any(), any());
}
private static class MockScheduler implements MetricUpdater.Scheduler {
diff --git a/container-jersey2/pom.xml b/container-jersey2/pom.xml
index 26dfa762032..c5ed7d872bf 100644
--- a/container-jersey2/pom.xml
+++ b/container-jersey2/pom.xml
@@ -52,11 +52,6 @@
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
- <version>5.0.3</version>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-library</artifactId>
</dependency>
</dependencies>
<build>
diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java
new file mode 100644
index 00000000000..7ff9646cb27
--- /dev/null
+++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java
@@ -0,0 +1,73 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.servlet.jersey;
+
+import com.yahoo.container.di.config.ResolveDependencyException;
+import com.yahoo.container.di.config.RestApiContext;
+import com.yahoo.container.jaxrs.annotation.Component;
+import org.glassfish.hk2.api.Injectee;
+import org.glassfish.hk2.api.InjectionResolver;
+import org.glassfish.hk2.api.ServiceHandle;
+
+import javax.inject.Singleton;
+
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Resolves jdisc container components for jersey 2 components.
+ *
+ * @author Tony Vaagenes
+ * @author ollivir
+ */
+@Singleton // jersey2 requirement: InjectionResolvers must be in the Singleton scope
+public class ComponentGraphProvider implements InjectionResolver<Component> {
+ private Collection<RestApiContext.Injectable> injectables;
+
+ public ComponentGraphProvider(Collection<RestApiContext.Injectable> injectables) {
+ this.injectables = injectables;
+ }
+
+ @Override
+ public Object resolve(Injectee injectee, ServiceHandle<?> root) {
+ Class<?> wantedClass;
+ Type type = injectee.getRequiredType();
+ if (type instanceof Class) {
+ wantedClass = (Class<?>) type;
+ } else {
+ throw new UnsupportedOperationException("Only classes are supported, got " + type);
+ }
+
+ List<RestApiContext.Injectable> componentsWithMatchingType = new ArrayList<>();
+ for (RestApiContext.Injectable injectable : injectables) {
+ if (wantedClass.isInstance(injectable.instance)) {
+ componentsWithMatchingType.add(injectable);
+ }
+ }
+
+ if (componentsWithMatchingType.size() == 1) {
+ return componentsWithMatchingType.get(0).instance;
+ } else {
+ String injectionDescription = "class '" + wantedClass + "' to inject into Jersey resource/provider '"
+ + injectee.getInjecteeClass() + "')";
+ if (componentsWithMatchingType.size() > 1) {
+ String ids = componentsWithMatchingType.stream().map(c -> c.id.toString()).collect(Collectors.joining(","));
+ throw new ResolveDependencyException("Multiple components found of " + injectionDescription + ": " + ids);
+ } else {
+ throw new ResolveDependencyException("Could not find a component of " + injectionDescription + ".");
+ }
+ }
+ }
+
+ @Override
+ public boolean isMethodParameterIndicator() {
+ return true;
+ }
+
+ @Override
+ public boolean isConstructorParameterIndicator() {
+ return true;
+ }
+}
diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java
new file mode 100644
index 00000000000..4c4e43bc8d5
--- /dev/null
+++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java
@@ -0,0 +1,25 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.servlet.jersey;
+
+import javax.ws.rs.core.Application;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * @author Tony Vaagenes
+ * @author ollivir
+ */
+public class JerseyApplication extends Application {
+ private Set<Class<?>> classes;
+
+ public JerseyApplication(Collection<Class<?>> resourcesAndProviderClasses) {
+ this.classes = new HashSet<>(resourcesAndProviderClasses);
+ }
+
+ @Override
+ public Set<Class<?>> getClasses() {
+ return classes;
+ }
+}
diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java
new file mode 100644
index 00000000000..1dbe410ba54
--- /dev/null
+++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java
@@ -0,0 +1,118 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.servlet.jersey;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
+import com.fasterxml.jackson.jaxrs.json.JacksonJaxbJsonProvider;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.di.config.RestApiContext;
+import com.yahoo.container.di.config.RestApiContext.BundleInfo;
+import com.yahoo.container.jaxrs.annotation.Component;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.glassfish.hk2.api.InjectionResolver;
+import org.glassfish.hk2.api.TypeLiteral;
+import org.glassfish.hk2.utilities.Binder;
+import org.glassfish.hk2.utilities.binding.AbstractBinder;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+import org.glassfish.jersey.server.ResourceConfig;
+import org.glassfish.jersey.servlet.ServletContainer;
+import org.objectweb.asm.ClassReader;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.container.servlet.jersey.util.ResourceConfigUtil.registerComponent;
+
+/**
+ * @author Tony Vaagenes
+ * @author ollivir
+ */
+public class JerseyServletProvider implements Provider<ServletHolder> {
+ private final ServletHolder jerseyServletHolder;
+
+ public JerseyServletProvider(RestApiContext restApiContext) {
+ this.jerseyServletHolder = new ServletHolder(new ServletContainer(resourceConfig(restApiContext)));
+ }
+
+ private ResourceConfig resourceConfig(RestApiContext restApiContext) {
+ final ResourceConfig resourceConfig = ResourceConfig
+ .forApplication(new JerseyApplication(resourcesAndProviders(restApiContext.getBundles())));
+
+ registerComponent(resourceConfig, componentInjectorBinder(restApiContext));
+ registerComponent(resourceConfig, jacksonDatatypeJdk8Provider());
+ resourceConfig.register(MultiPartFeature.class);
+
+ return resourceConfig;
+ }
+
+ private static Collection<Class<?>> resourcesAndProviders(Collection<BundleInfo> bundles) {
+ final List<Class<?>> ret = new ArrayList<>();
+
+ for (BundleInfo bundle : bundles) {
+ for (String classEntry : bundle.getClassEntries()) {
+ Optional<String> className = detectResourceOrProvider(bundle.classLoader, classEntry);
+ className.ifPresent(cname -> ret.add(loadClass(bundle.symbolicName, bundle.classLoader, cname)));
+ }
+ }
+ return ret;
+ }
+
+ private static Optional<String> detectResourceOrProvider(ClassLoader bundleClassLoader, String classEntry) {
+ try (InputStream inputStream = getResourceAsStream(bundleClassLoader, classEntry)) {
+ ResourceOrProviderClassVisitor visitor = ResourceOrProviderClassVisitor.visit(new ClassReader(inputStream));
+ return visitor.getJerseyClassName();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static InputStream getResourceAsStream(ClassLoader bundleClassLoader, String classEntry) {
+ InputStream is = bundleClassLoader.getResourceAsStream(classEntry);
+ if (is == null) {
+ throw new RuntimeException("No entry " + classEntry + " in bundle " + bundleClassLoader);
+ } else {
+ return is;
+ }
+ }
+
+ private static Class<?> loadClass(String bundleSymbolicName, ClassLoader classLoader, String className) {
+ try {
+ return classLoader.loadClass(className);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed loading class " + className + " from bundle " + bundleSymbolicName, e);
+ }
+ }
+
+ private static Binder componentInjectorBinder(RestApiContext restApiContext) {
+ final ComponentGraphProvider componentGraphProvider = new ComponentGraphProvider(restApiContext.getInjectableComponents());
+ final TypeLiteral<InjectionResolver<Component>> componentAnnotationType = new TypeLiteral<InjectionResolver<Component>>() {
+ };
+
+ return new AbstractBinder() {
+ @Override
+ public void configure() {
+ bind(componentGraphProvider).to(componentAnnotationType);
+ }
+ };
+ }
+
+ private static JacksonJaxbJsonProvider jacksonDatatypeJdk8Provider() {
+ JacksonJaxbJsonProvider provider = new JacksonJaxbJsonProvider();
+ provider.setMapper(new ObjectMapper().registerModule(new Jdk8Module()).registerModule(new JavaTimeModule()));
+ return provider;
+ }
+
+ @Override
+ public ServletHolder get() {
+ return jerseyServletHolder;
+ }
+
+ @Override
+ public void deconstruct() {
+ }
+}
diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java
new file mode 100644
index 00000000000..7cb47ac6118
--- /dev/null
+++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java
@@ -0,0 +1,103 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.servlet.jersey;
+
+import org.objectweb.asm.AnnotationVisitor;
+import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.ClassVisitor;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+
+import javax.ws.rs.Path;
+import javax.ws.rs.ext.Provider;
+
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * @author Tony Vaagenes
+ * @author ollivir
+ */
+public class ResourceOrProviderClassVisitor extends ClassVisitor {
+ private String className = null;
+ private boolean isPublic = false;
+ private boolean isAbstract = false;
+
+ private boolean isInnerClass = false;
+ private boolean isStatic = false;
+
+ private boolean isAnnotated = false;
+
+ public ResourceOrProviderClassVisitor() {
+ super(Opcodes.ASM6);
+ }
+
+ public Optional<String> getJerseyClassName() {
+ if (isJerseyClass()) {
+ return Optional.of(getClassName());
+ } else {
+ return Optional.empty();
+ }
+ }
+
+ public boolean isJerseyClass() {
+ return isAnnotated && isPublic && !isAbstract && (!isInnerClass || isStatic);
+ }
+
+ public String getClassName() {
+ assert (className != null);
+ return org.objectweb.asm.Type.getObjectType(className).getClassName();
+ }
+
+ @Override
+ public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
+ isPublic = isPublic(access);
+ className = name;
+ isAbstract = isAbstract(access);
+ }
+
+ @Override
+ public void visitInnerClass(String name, String outerName, String innerName, int access) {
+ assert (className != null);
+
+ if (name.equals(className)) {
+ isInnerClass = true;
+ isStatic = isStatic(access);
+ }
+ }
+
+ @Override
+ public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
+ isAnnotated |= annotationClassDescriptors.contains(desc);
+ return null;
+ }
+
+ private static Set<String> annotationClassDescriptors = new HashSet<>();
+
+ static {
+ annotationClassDescriptors.add(Type.getDescriptor(Path.class));
+ annotationClassDescriptors.add(Type.getDescriptor(Provider.class));
+ }
+
+ private static boolean isPublic(int access) {
+ return isSet(Opcodes.ACC_PUBLIC, access);
+ }
+
+ private static boolean isStatic(int access) {
+ return isSet(Opcodes.ACC_STATIC, access);
+ }
+
+ private static boolean isAbstract(int access) {
+ return isSet(Opcodes.ACC_ABSTRACT, access);
+ }
+
+ private static boolean isSet(int bits, int access) {
+ return (access & bits) == bits;
+ }
+
+ public static ResourceOrProviderClassVisitor visit(ClassReader classReader) {
+ ResourceOrProviderClassVisitor visitor = new ResourceOrProviderClassVisitor();
+ classReader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_CODE | ClassReader.SKIP_FRAMES);
+ return visitor;
+ }
+}
diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala
deleted file mode 100644
index cabde3680a4..00000000000
--- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.container.servlet.jersey
-
-import javax.inject.Singleton
-
-import com.yahoo.container.di.config.{ResolveDependencyException, RestApiContext}
-import com.yahoo.container.jaxrs.annotation.Component
-import org.glassfish.hk2.api.{ServiceHandle, Injectee, InjectionResolver}
-
-/**
- * Resolves jdisc container components for jersey 2 components.
- * Similar to Gjoran's ComponentGraphProvider for jersey 1.
- * @author tonytv
- */
-@Singleton //jersey2 requirement: InjectionResolvers must be in the Singleton scope
-class ComponentGraphProvider(injectables: Traversable[RestApiContext.Injectable]) extends InjectionResolver[Component] {
- override def resolve(injectee: Injectee, root: ServiceHandle[_]): AnyRef = {
- val wantedClass = injectee.getRequiredType match {
- case c: Class[_] => c
- case unsupported => throw new UnsupportedOperationException("Only classes are supported, got " + unsupported)
- }
-
- val componentsWithMatchingType = injectables.filter{ injectable =>
- wantedClass.isInstance(injectable.instance) }
-
- val injectionDescription =
- s"class '$wantedClass' to inject into Jersey resource/provider '${injectee.getInjecteeClass}')"
-
- if (componentsWithMatchingType.size > 1)
- throw new ResolveDependencyException(s"Multiple components found of $injectionDescription: " +
- componentsWithMatchingType.map(_.id).mkString(","))
-
- componentsWithMatchingType.headOption.map(_.instance).getOrElse {
- throw new ResolveDependencyException(s"Could not find a component of $injectionDescription.")
- }
- }
-
- override def isMethodParameterIndicator: Boolean = true
- override def isConstructorParameterIndicator: Boolean = true
-}
diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala
deleted file mode 100644
index eea41003984..00000000000
--- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.container.servlet.jersey
-
-import javax.ws.rs.core.Application
-
-import scala.collection.JavaConverters._
-
-/**
- * @author tonytv
- */
-class JerseyApplication(resourcesAndProviderClasses: Set[Class[_]]) extends Application {
- private val classes: java.util.Set[Class[_]] = resourcesAndProviderClasses.asJava
-
- override def getClasses = classes
- override def getSingletons = super.getSingletons
-}
diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala
deleted file mode 100644
index f0eff54dc16..00000000000
--- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.container.servlet.jersey
-
-import java.io.{IOException, InputStream}
-
-import com.fasterxml.jackson.databind.ObjectMapper
-import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
-import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
-import com.fasterxml.jackson.jaxrs.json.JacksonJaxbJsonProvider
-import com.yahoo.container.di.componentgraph.Provider
-import com.yahoo.container.di.config.RestApiContext
-import com.yahoo.container.di.config.RestApiContext.BundleInfo
-import com.yahoo.container.jaxrs.annotation.Component
-import com.yahoo.container.servlet.jersey.util.ResourceConfigUtil.registerComponent
-import org.eclipse.jetty.servlet.ServletHolder
-import org.glassfish.hk2.api.{InjectionResolver, TypeLiteral}
-import org.glassfish.hk2.utilities.Binder
-import org.glassfish.hk2.utilities.binding.AbstractBinder
-import org.glassfish.jersey.media.multipart.MultiPartFeature
-import org.glassfish.jersey.server.ResourceConfig
-import org.glassfish.jersey.servlet.ServletContainer
-import org.objectweb.asm.ClassReader
-
-import scala.collection.JavaConverters._
-import scala.util.control.Exception
-
-
-/**
- * @author tonytv
- */
-class JerseyServletProvider(restApiContext: RestApiContext) extends Provider[ServletHolder] {
- private val jerseyServletHolder = new ServletHolder(new ServletContainer(resourceConfig(restApiContext)))
-
- private def resourceConfig(restApiContext: RestApiContext) = {
- val resourceConfig = ResourceConfig.forApplication(
- new JerseyApplication(resourcesAndProviders(restApiContext.getBundles.asScala)))
-
- registerComponent(resourceConfig, componentInjectorBinder(restApiContext))
- registerComponent(resourceConfig, jacksonDatatypeJdk8Provider)
- resourceConfig.register(classOf[MultiPartFeature])
-
- resourceConfig
- }
-
- def resourcesAndProviders(bundles: Traversable[BundleInfo]) =
- (for {
- bundle <- bundles.view
- classEntry <- bundle.getClassEntries.asScala
- className <- detectResourceOrProvider(bundle.classLoader, classEntry)
- } yield loadClass(bundle.symbolicName, bundle.classLoader, className)).toSet
-
-
- def detectResourceOrProvider(bundleClassLoader: ClassLoader, classEntry: String): Option[String] = {
- using(getResourceAsStream(bundleClassLoader, classEntry)) { inputStream =>
- val visitor = ResourceOrProviderClassVisitor.visit(new ClassReader(inputStream))
- visitor.getJerseyClassName
- }
- }
-
- private def getResourceAsStream(bundleClassLoader: ClassLoader, classEntry: String) = {
- bundleClassLoader.getResourceAsStream(classEntry) match {
- case null => throw new RuntimeException(s"No entry $classEntry in bundle $bundleClassLoader")
- case stream => stream
- }
-
- }
-
- def using[T <: InputStream, R](stream: T)(f: T => R): R = {
- try {
- f(stream)
- } finally {
- Exception.ignoring(classOf[IOException]) {
- stream.close()
- }
- }
- }
-
- def loadClass(bundleSymbolicName: String, classLoader: ClassLoader, className: String) = {
- try {
- classLoader.loadClass(className)
- } catch {
- case e: Exception => throw new RuntimeException(s"Failed loading class $className from bundle $bundleSymbolicName", e)
- }
- }
-
- def componentInjectorBinder(restApiContext: RestApiContext): Binder = {
- val componentGraphProvider = new ComponentGraphProvider(restApiContext.getInjectableComponents.asScala)
- val componentAnnotationType = new TypeLiteral[InjectionResolver[Component]] {}
-
- new AbstractBinder {
- override def configure() {
- bind(componentGraphProvider).to(componentAnnotationType)
- }
- }
- }
-
- def jacksonDatatypeJdk8Provider: JacksonJaxbJsonProvider = {
- val provider = new JacksonJaxbJsonProvider()
- provider.setMapper(
- new ObjectMapper()
- .registerModule(new Jdk8Module)
- .registerModule(new JavaTimeModule))
- provider
- }
-
- override def get() = jerseyServletHolder
- override def deconstruct() {}
-}
-
diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala
deleted file mode 100644
index c015f11360e..00000000000
--- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.container.servlet.jersey
-
-import javax.ws.rs.Path
-import javax.ws.rs.ext.Provider
-
-import org.objectweb.asm.{ClassVisitor, Opcodes, Type, AnnotationVisitor, ClassReader}
-
-
-/**
- * @author tonytv
- */
-class ResourceOrProviderClassVisitor private () extends ClassVisitor(Opcodes.ASM5) {
- private var className: String = null
- private var isPublic: Boolean = false
- private var isAbstract = false
-
- private var isInnerClass: Boolean = false
- private var isStatic: Boolean = false
-
- private var isAnnotated: Boolean = false
-
- def getJerseyClassName: Option[String] = {
- if (isJerseyClass) Some(getClassName)
- else None
- }
-
- def isJerseyClass: Boolean = {
- isAnnotated && isPublic && !isAbstract &&
- (!isInnerClass || isStatic)
- }
-
- def getClassName = {
- assert (className != null)
- Type.getObjectType(className).getClassName
- }
-
- override def visit(version: Int, access: Int, name: String, signature: String, superName: String, interfaces: Array[String]) {
- isPublic = ResourceOrProviderClassVisitor.isPublic(access)
- className = name
- isAbstract = ResourceOrProviderClassVisitor.isAbstract(access)
- }
-
- override def visitInnerClass(name: String, outerName: String, innerName: String, access: Int) {
- assert (className != null)
-
- if (name == className) {
- isInnerClass = true
- isStatic = ResourceOrProviderClassVisitor.isStatic(access)
- }
- }
-
- override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor = {
- isAnnotated |= ResourceOrProviderClassVisitor.annotationClassDescriptors(desc)
- null
- }
-}
-
-
-object ResourceOrProviderClassVisitor {
- val annotationClassDescriptors = Set(classOf[Path], classOf[Provider]) map Type.getDescriptor
-
- def isPublic = isSet(Opcodes.ACC_PUBLIC) _
- def isStatic = isSet(Opcodes.ACC_STATIC) _
- def isAbstract = isSet(Opcodes.ACC_ABSTRACT) _
-
- private def isSet(bits: Int)(access: Int): Boolean = (access & bits) == bits
-
- def visit(classReader: ClassReader): ResourceOrProviderClassVisitor = {
- val visitor = new ResourceOrProviderClassVisitor
- classReader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_CODE | ClassReader.SKIP_FRAMES)
- visitor
- }
-}
diff --git a/container/pom.xml b/container/pom.xml
index d252a5eee4a..32a7947d6d5 100644
--- a/container/pom.xml
+++ b/container/pom.xml
@@ -21,6 +21,12 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>container-dev</artifactId>
<version>${project.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java
index 00c0d87554a..776002f31cb 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java
@@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.controller.api.integration.organization;
import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId;
-import java.io.UncheckedIOException;
import java.net.URI;
import java.time.Duration;
import java.util.List;
@@ -87,8 +86,9 @@ public interface Organization {
*
* @param issueId ID of the issue to escalate.
* @param propertyId PropertyId of the tenant owning the application for which the issue was filed.
+ * @return User that was assigned issue as a result of the escalation, if any
*/
- default boolean escalate(IssueId issueId, PropertyId propertyId) {
+ default Optional<User> escalate(IssueId issueId, PropertyId propertyId) {
List<? extends List<? extends User>> contacts = contactsFor(propertyId);
Optional<User> assignee = assigneeOf(issueId);
@@ -101,9 +101,9 @@ public interface Organization {
for (int level = assigneeLevel + 1; level < contacts.size(); level++)
for (User target : contacts.get(level))
if (reassign(issueId, target))
- return true;
+ return Optional.of(target);
- return false;
+ return Optional.empty();
}
/**
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java
index 295b1adbca9..295e0102782 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java
@@ -10,6 +10,7 @@ import com.yahoo.config.provision.Environment;
import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics;
import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
+import com.yahoo.vespa.hosted.controller.application.ApplicationActivity;
import com.yahoo.vespa.hosted.controller.application.ApplicationRotation;
import com.yahoo.vespa.hosted.controller.application.ApplicationVersion;
import com.yahoo.vespa.hosted.controller.application.Change;
@@ -142,14 +143,21 @@ public class Application {
*/
public Change outstandingChange() { return outstandingChange; }
+ /** Returns ID of the last ownership issue filed for this */
public Optional<IssueId> ownershipIssueId() {
return ownershipIssueId;
}
+ /** Returns metrics for this */
public ApplicationMetrics metrics() {
return metrics;
}
+ /** Returns activity for this */
+ public ApplicationActivity activity() {
+ return ApplicationActivity.from(deployments.values());
+ }
+
/**
* Returns the oldest platform version this has deployed in a permanent zone (not test or staging).
*/
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
index 8b0dc35e16b..f0e278c3e6d 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.component.Version;
import com.yahoo.config.application.api.DeploymentSpec;
import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.Environment;
import com.yahoo.config.provision.TenantName;
@@ -109,7 +110,7 @@ public class ApplicationController {
this.artifactRepository = artifactRepository;
this.rotationRepository = new RotationRepository(rotationsConfig, this, curator);
- this.deploymentTrigger = new DeploymentTrigger(controller, curator, buildService, clock);
+ this.deploymentTrigger = new DeploymentTrigger(controller, buildService, clock);
for (Application application : curator.readApplications()) {
lockIfPresent(application.id(), this::store);
@@ -256,7 +257,7 @@ public class ApplicationController {
LockedApplication application = new LockedApplication(new Application(id), lock);
store(application);
log.info("Created " + application);
- return application;
+ return application.get();
}
}
@@ -285,7 +286,7 @@ public class ApplicationController {
} else {
JobType jobType = JobType.from(controller.system(), zone)
.orElseThrow(() -> new IllegalArgumentException("No job found for zone " + zone));
- Optional<JobStatus> job = Optional.ofNullable(application.deploymentJobs().jobStatus().get(jobType));
+ Optional<JobStatus> job = Optional.ofNullable(application.get().deploymentJobs().jobStatus().get(jobType));
if ( ! job.isPresent()
|| ! job.get().lastTriggered().isPresent()
|| job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at()))
@@ -297,8 +298,8 @@ public class ApplicationController {
applicationVersion = preferOldestVersion
? triggered.sourceApplication().orElse(triggered.application())
: triggered.application();
- applicationPackage = new ApplicationPackage(artifactRepository.getApplicationPackage(application.id(), applicationVersion.id()));
- validateRun(application, zone, platformVersion, applicationVersion);
+ applicationPackage = new ApplicationPackage(artifactRepository.getApplicationPackage(application.get().id(), applicationVersion.id()));
+ validateRun(application.get(), zone, platformVersion, applicationVersion);
}
validate(applicationPackage.deploymentSpec());
@@ -323,7 +324,7 @@ public class ApplicationController {
application = withRotation(application, zone);
Set<String> rotationNames = new HashSet<>();
Set<String> cnames = new HashSet<>();
- application.rotation().ifPresent(applicationRotation -> {
+ application.get().rotation().ifPresent(applicationRotation -> {
rotationNames.add(applicationRotation.id().asString());
cnames.add(applicationRotation.dnsName());
cnames.add(applicationRotation.secureDnsName());
@@ -366,15 +367,15 @@ public class ApplicationController {
/** Makes sure the application has a global rotation, if eligible. */
private LockedApplication withRotation(LockedApplication application, ZoneId zone) {
- if (zone.environment() == Environment.prod && application.deploymentSpec().globalServiceId().isPresent()) {
+ if (zone.environment() == Environment.prod && application.get().deploymentSpec().globalServiceId().isPresent()) {
try (RotationLock rotationLock = rotationRepository.lock()) {
- Rotation rotation = rotationRepository.getOrAssignRotation(application, rotationLock);
+ Rotation rotation = rotationRepository.getOrAssignRotation(application.get(), rotationLock);
application = application.with(rotation.id());
store(application); // store assigned rotation even if deployment fails
- registerRotationInDns(rotation, application.rotation().get().dnsName());
- registerRotationInDns(rotation, application.rotation().get().secureDnsName());
- registerRotationInDns(rotation, application.rotation().get().oathDnsName());
+ registerRotationInDns(rotation, application.get().rotation().get().dnsName());
+ registerRotationInDns(rotation, application.get().rotation().get().secureDnsName());
+ registerRotationInDns(rotation, application.get().rotation().get().oathDnsName());
}
}
return application;
@@ -394,22 +395,23 @@ public class ApplicationController {
}
private LockedApplication deleteRemovedDeployments(LockedApplication application) {
- List<Deployment> deploymentsToRemove = application.productionDeployments().values().stream()
- .filter(deployment -> ! application.deploymentSpec().includes(deployment.zone().environment(),
- Optional.of(deployment.zone().region())))
+ List<Deployment> deploymentsToRemove = application.get().productionDeployments().values().stream()
+ .filter(deployment -> ! application.get().deploymentSpec().includes(deployment.zone().environment(),
+ Optional.of(deployment.zone().region())))
.collect(Collectors.toList());
if (deploymentsToRemove.isEmpty()) return application;
- if ( ! application.validationOverrides().allows(ValidationId.deploymentRemoval, clock.instant()))
- throw new IllegalArgumentException(ValidationId.deploymentRemoval.value() + ": " + application +
+ if ( ! application.get().validationOverrides().allows(ValidationId.deploymentRemoval, clock.instant()))
+ throw new IllegalArgumentException(ValidationId.deploymentRemoval.value() + ": " + application.get() +
" is deployed in " +
deploymentsToRemove.stream()
.map(deployment -> deployment.zone().region().value())
.collect(Collectors.joining(", ")) +
", but does not include " +
(deploymentsToRemove.size() > 1 ? "these zones" : "this zone") +
- " in deployment.xml");
+ " in deployment.xml. " +
+ ValidationOverrides.toAllowMessage(ValidationId.deploymentRemoval));
LockedApplication applicationWithRemoval = application;
for (Deployment deployment : deploymentsToRemove)
@@ -418,10 +420,11 @@ public class ApplicationController {
}
private LockedApplication deleteUnreferencedDeploymentJobs(LockedApplication application) {
- for (JobType job : application.deploymentJobs().jobStatus().keySet()) {
+ for (JobType job : application.get().deploymentJobs().jobStatus().keySet()) {
Optional<ZoneId> zone = job.zone(controller.system());
- if ( ! job.isProduction() || (zone.isPresent() && application.deploymentSpec().includes(zone.get().environment(), zone.map(ZoneId::region))))
+ if ( ! job.isProduction() || (zone.isPresent() && application.get().deploymentSpec().includes(
+ zone.get().environment(), zone.map(ZoneId::region))))
continue;
application = application.withoutDeploymentJob(job);
}
@@ -493,7 +496,7 @@ public class ApplicationController {
// TODO: Make this one transaction when database is moved to ZooKeeper
instances.forEach(id -> lockOrThrow(id, application -> {
- if ( ! application.deployments().isEmpty())
+ if ( ! application.get().deployments().isEmpty())
throw new IllegalArgumentException("Could not delete '" + application + "': It has active deployments");
Tenant tenant = controller.tenants().tenant(id.tenant()).get();
@@ -518,7 +521,7 @@ public class ApplicationController {
* @param application a locked application to store
*/
public void store(LockedApplication application) {
- curator.writeApplication(application);
+ curator.writeApplication(application.get());
}
/**
@@ -572,7 +575,7 @@ public class ApplicationController {
*/
private LockedApplication deactivate(LockedApplication application, ZoneId zone) {
try {
- configServer.deactivate(new DeploymentId(application.id(), zone));
+ configServer.deactivate(new DeploymentId(application.get().id(), zone));
}
catch (NoInstanceException ignored) {
// ok; already gone
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java
index 913adf06f22..3207d4b8399 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java
@@ -26,17 +26,29 @@ import com.yahoo.vespa.hosted.controller.rotation.RotationId;
import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
/**
- * A combination of an application instance and a lock for that application. Provides methods for updating application
- * fields.
+ * An application that has been locked for modification. Provides methods for modifying an application's fields.
*
* @author mpolden
* @author jvenstad
*/
-public class LockedApplication extends Application {
+public class LockedApplication {
+
+ private final Lock lock;
+ private final ApplicationId id;
+ private final DeploymentSpec deploymentSpec;
+ private final ValidationOverrides validationOverrides;
+ private final Map<ZoneId, Deployment> deployments;
+ private final DeploymentJobs deploymentJobs;
+ private final Change change;
+ private final Change outstandingChange;
+ private final Optional<IssueId> ownershipIssueId;
+ private final ApplicationMetrics metrics;
+ private final Optional<RotationId> rotation;
/**
* Used to create a locked application
@@ -44,180 +56,172 @@ public class LockedApplication extends Application {
* @param application The application to lock.
* @param lock The lock for the application.
*/
- LockedApplication(Application application, @SuppressWarnings("unused") Lock lock) {
- this(new Builder(application));
- }
-
- private LockedApplication(Builder builder) {
- super(builder.applicationId, builder.deploymentSpec, builder.validationOverrides,
- builder.deployments, builder.deploymentJobs, builder.deploying,
- builder.outstandingChange, builder.ownershipIssueId, builder.metrics, builder.rotation);
+ LockedApplication(Application application, Lock lock) {
+ this(Objects.requireNonNull(lock, "lock cannot be null"), application.id(),
+ application.deploymentSpec(), application.validationOverrides(),
+ application.deployments(),
+ application.deploymentJobs(), application.change(), application.outstandingChange(),
+ application.ownershipIssueId(), application.metrics(),
+ application.rotation().map(ApplicationRotation::id));
+ }
+
+ private LockedApplication(Lock lock, ApplicationId id,
+ DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides,
+ Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Change change,
+ Change outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics,
+ Optional<RotationId> rotation) {
+ this.lock = lock;
+ this.id = id;
+ this.deploymentSpec = deploymentSpec;
+ this.validationOverrides = validationOverrides;
+ this.deployments = deployments;
+ this.deploymentJobs = deploymentJobs;
+ this.change = change;
+ this.outstandingChange = outstandingChange;
+ this.ownershipIssueId = ownershipIssueId;
+ this.metrics = metrics;
+ this.rotation = rotation;
+ }
+
+ /** Returns a read-only copy of this */
+ public Application get() {
+ return new Application(id, deploymentSpec, validationOverrides, deployments, deploymentJobs, change,
+ outstandingChange, ownershipIssueId, metrics, rotation);
}
public LockedApplication withProjectId(OptionalLong projectId) {
- return new LockedApplication(new Builder(this).with(deploymentJobs().withProjectId(projectId)));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs.withProjectId(projectId), change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication withDeploymentIssueId(IssueId issueId) {
- return new LockedApplication(new Builder(this).with(deploymentJobs().with(issueId)));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs.with(issueId), change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
- public LockedApplication withJobCompletion(long projectId, JobType jobType, JobStatus.JobRun completion, Optional<DeploymentJobs.JobError> jobError) {
- return new LockedApplication(new Builder(this).with(deploymentJobs().withCompletion(projectId, jobType, completion, jobError))
- );
+ public LockedApplication withJobCompletion(long projectId, JobType jobType, JobStatus.JobRun completion,
+ Optional<DeploymentJobs.JobError> jobError) {
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs.withCompletion(projectId, jobType, completion, jobError),
+ change, outstandingChange, ownershipIssueId, metrics, rotation);
}
public LockedApplication withJobTriggering(JobType jobType, JobStatus.JobRun job) {
- return new LockedApplication(new Builder(this).with(deploymentJobs().withTriggering(jobType, job)));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs.withTriggering(jobType, job), change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication withNewDeployment(ZoneId zone, ApplicationVersion applicationVersion, Version version,
Instant instant) {
// Use info from previous deployment if available, otherwise create a new one.
- Deployment previousDeployment = deployments().getOrDefault(zone, new Deployment(zone, applicationVersion,
- version, instant));
+ Deployment previousDeployment = deployments.getOrDefault(zone, new Deployment(zone, applicationVersion,
+ version, instant));
Deployment newDeployment = new Deployment(zone, applicationVersion, version, instant,
previousDeployment.clusterUtils(),
previousDeployment.clusterInfo(),
- previousDeployment.metrics());
+ previousDeployment.metrics(),
+ previousDeployment.activity());
return with(newDeployment);
}
public LockedApplication withClusterUtilization(ZoneId zone, Map<ClusterSpec.Id, ClusterUtilization> clusterUtilization) {
- Deployment deployment = deployments().get(zone);
+ Deployment deployment = deployments.get(zone);
if (deployment == null) return this; // No longer deployed in this zone.
return with(deployment.withClusterUtils(clusterUtilization));
}
public LockedApplication withClusterInfo(ZoneId zone, Map<ClusterSpec.Id, ClusterInfo> clusterInfo) {
- Deployment deployment = deployments().get(zone);
+ Deployment deployment = deployments.get(zone);
if (deployment == null) return this; // No longer deployed in this zone.
return with(deployment.withClusterInfo(clusterInfo));
}
+ public LockedApplication recordActivityAt(Instant instant, ZoneId zone) {
+ Deployment deployment = deployments.get(zone);
+ if (deployment == null) return this;
+ return with(deployment.recordActivityAt(instant));
+ }
+
public LockedApplication with(ZoneId zone, DeploymentMetrics deploymentMetrics) {
- Deployment deployment = deployments().get(zone);
+ Deployment deployment = deployments.get(zone);
if (deployment == null) return this; // No longer deployed in this zone.
return with(deployment.withMetrics(deploymentMetrics));
}
public LockedApplication withoutDeploymentIn(ZoneId zone) {
- Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(deployments());
+ Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(this.deployments);
deployments.remove(zone);
- return new LockedApplication(new Builder(this).with(deployments));
+ return with(deployments);
}
public LockedApplication withoutDeploymentJob(DeploymentJobs.JobType jobType) {
- return new LockedApplication(new Builder(this).with(deploymentJobs().without(jobType)));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs.without(jobType), change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication with(DeploymentSpec deploymentSpec) {
- return new LockedApplication(new Builder(this).with(deploymentSpec));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication with(ValidationOverrides validationOverrides) {
- return new LockedApplication(new Builder(this).with(validationOverrides));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication withChange(Change change) {
- return new LockedApplication(new Builder(this).withChange(change));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication withOutstandingChange(Change outstandingChange) {
- return new LockedApplication(new Builder(this).withOutstandingChange(outstandingChange));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication withOwnershipIssueId(IssueId issueId) {
- return new LockedApplication(new Builder(this).withOwnershipIssueId(Optional.ofNullable(issueId)));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ Optional.ofNullable(issueId), metrics, rotation);
}
public LockedApplication with(MetricsService.ApplicationMetrics metrics) {
- return new LockedApplication(new Builder(this).with(metrics));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
}
public LockedApplication with(RotationId rotation) {
- return new LockedApplication(new Builder(this).with(rotation));
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, Optional.of(rotation));
}
/** Don't expose non-leaf sub-objects. */
private LockedApplication with(Deployment deployment) {
- Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(deployments());
+ Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(this.deployments);
deployments.put(deployment.zone(), deployment);
- return new LockedApplication(new Builder(this).with(deployments));
- }
-
- private static class Builder {
-
- private final ApplicationId applicationId;
- private DeploymentSpec deploymentSpec;
- private ValidationOverrides validationOverrides;
- private Map<ZoneId, Deployment> deployments;
- private DeploymentJobs deploymentJobs;
- private Change deploying;
- private Change outstandingChange;
- private Optional<IssueId> ownershipIssueId;
- private ApplicationMetrics metrics;
- private Optional<RotationId> rotation;
-
- private Builder(Application application) {
- this.applicationId = application.id();
- this.deploymentSpec = application.deploymentSpec();
- this.validationOverrides = application.validationOverrides();
- this.deployments = application.deployments();
- this.deploymentJobs = application.deploymentJobs();
- this.deploying = application.change();
- this.outstandingChange = application.outstandingChange();
- this.ownershipIssueId = application.ownershipIssueId();
- this.metrics = application.metrics();
- this.rotation = application.rotation().map(ApplicationRotation::id);
- }
-
- private Builder with(DeploymentSpec deploymentSpec) {
- this.deploymentSpec = deploymentSpec;
- return this;
- }
-
- private Builder with(ValidationOverrides validationOverrides) {
- this.validationOverrides = validationOverrides;
- return this;
- }
-
- private Builder with(Map<ZoneId, Deployment> deployments) {
- this.deployments = deployments;
- return this;
- }
-
- private Builder with(DeploymentJobs deploymentJobs) {
- this.deploymentJobs = deploymentJobs;
- return this;
- }
-
- private Builder withChange(Change deploying) {
- this.deploying = deploying;
- return this;
- }
-
- private Builder withOutstandingChange(Change outstandingChange) {
- this.outstandingChange = outstandingChange;
- return this;
- }
-
- private Builder withOwnershipIssueId(Optional<IssueId> ownershipIssueId) {
- this.ownershipIssueId = ownershipIssueId;
- return this;
- }
-
- private Builder with(ApplicationMetrics metrics) {
- this.metrics = metrics;
- return this;
- }
-
- private Builder with(RotationId rotation) {
- this.rotation = Optional.of(rotation);
- return this;
- }
+ return with(deployments);
+ }
+
+ private LockedApplication with(Map<ZoneId, Deployment> deployments) {
+ return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments,
+ deploymentJobs, change, outstandingChange,
+ ownershipIssueId, metrics, rotation);
+ }
+ @Override
+ public String toString() {
+ return "application '" + id + "'";
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java
new file mode 100644
index 00000000000..ddd519382a6
--- /dev/null
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java
@@ -0,0 +1,56 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.application;
+
+import java.time.Instant;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Recent activity in an application.
+ *
+ * @author mpolden
+ */
+public class ApplicationActivity {
+
+ public static final ApplicationActivity none = new ApplicationActivity(Optional.empty(), Optional.empty());
+
+ private final Optional<Instant> lastQueried;
+ private final Optional<Instant> lastWritten;
+
+ private ApplicationActivity(Optional<Instant> lastQueried, Optional<Instant> lastWritten) {
+ this.lastQueried = lastQueried;
+ this.lastWritten = lastWritten;
+ }
+
+ /** The last time any deployment in this was queried */
+ public Optional<Instant> lastQueried() {
+ return lastQueried;
+ }
+
+ /** The last time any deployment in this was written */
+ public Optional<Instant> lastWritten() {
+ return lastWritten;
+ }
+
+ public static ApplicationActivity from(Collection<Deployment> deployments) {
+ Optional<Instant> lastQueried = lastActivity(deployments, DeploymentActivity::lastQueried);
+ Optional<Instant> lastWritten = lastActivity(deployments, DeploymentActivity::lastWritten);
+ if (!lastQueried.isPresent() && !lastWritten.isPresent()) {
+ return none;
+ }
+ return new ApplicationActivity(lastQueried, lastWritten);
+ }
+
+ private static Optional<Instant> lastActivity(Collection<Deployment> deployments,
+ Function<DeploymentActivity, Optional<Instant>> activityField) {
+ return deployments.stream()
+ .map(Deployment::activity)
+ .map(activityField)
+ .filter(Optional::isPresent)
+ .map(Optional::get)
+ .max(Comparator.naturalOrder());
+ }
+
+}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java
index 6df8e901653..40e2e4a92d1 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java
@@ -34,9 +34,8 @@ public class ApplicationPackage {
* it must not be further changed by the caller.
*/
public ApplicationPackage(byte[] zippedContent) {
- Objects.requireNonNull(zippedContent, "The application package content cannot be null");
+ this.zippedContent = Objects.requireNonNull(zippedContent, "The application package content cannot be null");
this.contentHash = DigestUtils.shaHex(zippedContent);
- this.zippedContent = zippedContent;
this.deploymentSpec = extractFile("deployment.xml", zippedContent).map(DeploymentSpec::fromXml).orElse(DeploymentSpec.empty);
this.validationOverrides = extractFile("validation-overrides.xml", zippedContent).map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty);
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java
index 8fa0c6da49c..0a062427a8a 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java
@@ -6,6 +6,7 @@ import com.yahoo.config.provision.ClusterSpec.Id;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import java.time.Instant;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@@ -25,27 +26,25 @@ public class Deployment {
private final Map<Id, ClusterUtilization> clusterUtils;
private final Map<Id, ClusterInfo> clusterInfo;
private final DeploymentMetrics metrics;
+ private final DeploymentActivity activity;
public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime) {
- this(zone, applicationVersion, version, deployTime, new HashMap<>(), new HashMap<>(), new DeploymentMetrics());
+ this(zone, applicationVersion, version, deployTime, Collections.emptyMap(), Collections.emptyMap(),
+ new DeploymentMetrics(), DeploymentActivity.none);
}
public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime,
- Map<Id, ClusterUtilization> clusterUtils, Map<Id, ClusterInfo> clusterInfo, DeploymentMetrics metrics) {
- Objects.requireNonNull(zone, "zone cannot be null");
- Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null");
- Objects.requireNonNull(version, "version cannot be null");
- Objects.requireNonNull(deployTime, "deployTime cannot be null");
- Objects.requireNonNull(clusterUtils, "clusterUtils cannot be null");
- Objects.requireNonNull(clusterInfo, "clusterInfo cannot be null");
- Objects.requireNonNull(metrics, "deployment metrics cannot be null");
- this.zone = zone;
- this.applicationVersion = applicationVersion;
- this.version = version;
- this.deployTime = deployTime;
- this.clusterUtils = clusterUtils;
- this.clusterInfo = clusterInfo;
- this.metrics = metrics;
+ Map<Id, ClusterUtilization> clusterUtils, Map<Id, ClusterInfo> clusterInfo,
+ DeploymentMetrics metrics,
+ DeploymentActivity activity) {
+ this.zone = Objects.requireNonNull(zone, "zone cannot be null");
+ this.applicationVersion = Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null");
+ this.version = Objects.requireNonNull(version, "version cannot be null");
+ this.deployTime = Objects.requireNonNull(deployTime, "deployTime cannot be null");
+ this.clusterUtils = Objects.requireNonNull(clusterUtils, "clusterUtils cannot be null");
+ this.clusterInfo = Objects.requireNonNull(clusterInfo, "clusterInfo cannot be null");
+ this.metrics = Objects.requireNonNull(metrics, "deploymentMetrics cannot be null");
+ this.activity = Objects.requireNonNull(activity, "activity cannot be null");
}
/** Returns the zone this was deployed to */
@@ -60,29 +59,42 @@ public class Deployment {
/** Returns the time this was deployed */
public Instant at() { return deployTime; }
+ /** Returns metrics for this */
+ public DeploymentMetrics metrics() {
+ return metrics;
+ }
+
+ /** Returns activity for this */
+ public DeploymentActivity activity() { return activity; }
+
+ /** Returns information about the clusters allocated to this */
public Map<Id, ClusterInfo> clusterInfo() {
return clusterInfo;
}
+ /** Returns utilization of the clusters allocated to this */
public Map<Id, ClusterUtilization> clusterUtils() {
return clusterUtils;
}
+ public Deployment recordActivityAt(Instant instant) {
+ return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics,
+ activity.recordAt(instant, metrics));
+ }
+
public Deployment withClusterUtils(Map<Id, ClusterUtilization> clusterUtilization) {
- return new Deployment(zone, applicationVersion, version, deployTime, clusterUtilization, clusterInfo, metrics);
+ return new Deployment(zone, applicationVersion, version, deployTime, clusterUtilization, clusterInfo, metrics,
+ activity);
}
public Deployment withClusterInfo(Map<Id, ClusterInfo> newClusterInfo) {
- return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, newClusterInfo, metrics);
+ return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, newClusterInfo, metrics,
+ activity);
}
public Deployment withMetrics(DeploymentMetrics metrics) {
- return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics);
- }
-
- /** @return Key metrics for the deployment (application level) like QPS and document count */
- public DeploymentMetrics metrics() {
- return metrics;
+ return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics,
+ activity);
}
/**
@@ -109,4 +121,5 @@ public class Deployment {
public String toString() {
return "deployment to " + zone + " of " + applicationVersion + " on version " + version + " at " + deployTime;
}
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java
new file mode 100644
index 00000000000..d4635212e80
--- /dev/null
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java
@@ -0,0 +1,55 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.application;
+
+import java.time.Instant;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Recent activity in a deployment.
+ *
+ * @author mpolden
+ */
+public class DeploymentActivity {
+
+ /** Query rates at or below this threshold indicate inactivity */
+ private static final double inactivityThreshold = 0;
+
+ public static final DeploymentActivity none = new DeploymentActivity(Optional.empty(), Optional.empty());
+
+ private final Optional<Instant> lastQueried;
+ private final Optional<Instant> lastWritten;
+
+ private DeploymentActivity(Optional<Instant> lastQueried, Optional<Instant> lastWritten) {
+ this.lastQueried = Objects.requireNonNull(lastQueried, "lastQueried must be non-null");
+ this.lastWritten = Objects.requireNonNull(lastWritten, "lastWritten must be non-null");
+ }
+
+ /** The last time this deployment received queries (search) */
+ public Optional<Instant> lastQueried() {
+ return lastQueried;
+ }
+
+ /** The last time this deployment received writes (feed) */
+ public Optional<Instant> lastWritten() {
+ return lastWritten;
+ }
+
+ /** Record activity using given metrics */
+ public DeploymentActivity recordAt(Instant instant, DeploymentMetrics metrics) {
+ return new DeploymentActivity(activityAt(instant, lastQueried, metrics.queriesPerSecond()),
+ activityAt(instant, lastWritten, metrics.writesPerSecond()));
+ }
+
+ public static DeploymentActivity create(Optional<Instant> queriedAt, Optional<Instant> writtenAt) {
+ if (!queriedAt.isPresent() && !writtenAt.isPresent()) {
+ return none;
+ }
+ return new DeploymentActivity(queriedAt, writtenAt);
+ }
+
+ private static Optional<Instant> activityAt(Instant newInstant, Optional<Instant> oldInstant, double rate) {
+ return rate > inactivityThreshold ? Optional.of(newInstant) : oldInstant;
+ }
+
+}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java
deleted file mode 100644
index 6168812203a..00000000000
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.controller.concurrent;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.locks.ReentrantLock;
-
-/**
- * Holds a map of locks indexed on keys of a given type.
- * This is suitable in cases where exclusive access should be granted to any one of a set of keyed objects and
- * there is a finite collection of keyed objects.
- *
- * The returned locks are reentrant (i.e the owning thread may call lock multiple times) and auto-closable.
- *
- * Typical use is
- * <code>
- * try (Lock lock = locks.lock(id)) {
- * exclusive use of the object with key id
- * }
- * </code>
- *
- * @author bratseth
- */
-public class Locks<TYPE> {
-
- private final Map<TYPE, ReentrantLock> locks = new ConcurrentHashMap<>();
-
- private final long timeoutMs;
-
- public Locks(int timeout, TimeUnit timeoutUnit) {
- timeoutMs = timeoutUnit.toMillis(timeout);
- }
-
- /**
- * Locks key. This will block until the key is acquired.
- * Users of this <b>must</b> close any lock acquired.
- *
- * @param key the key to lock
- * @return the acquired lock
- * @throws TimeoutException if the lock could not be acquired within the timeout
- */
- public Lock lock(TYPE key) {
- try {
- ReentrantLock lock = locks.computeIfAbsent(key, k -> new ReentrantLock(true));
- boolean acquired = lock.tryLock(timeoutMs, TimeUnit.MILLISECONDS);
- if ( ! acquired)
- throw new TimeoutException("Timed out waiting for the lock to " + key);
- return new Lock(lock);
- } catch (InterruptedException e) {
- throw new RuntimeException("Interrupted while waiting for lock of " + key);
- }
- }
-
-}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java
index 405c8d17263..1c535a5a331 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java
@@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.controller.deployment;
import com.yahoo.config.application.api.DeploymentSpec;
import com.yahoo.config.provision.SystemName;
-import com.yahoo.vespa.hosted.controller.Controller;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import com.yahoo.vespa.hosted.controller.application.Deployment;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs;
@@ -30,8 +29,7 @@ public class DeploymentOrder {
private final Supplier<SystemName> system;
public DeploymentOrder(Supplier<SystemName> system) {
- Objects.requireNonNull(system, "system may not be null");
- this.system = system;
+ this.system = Objects.requireNonNull(system, "system may not be null");
}
/** Returns jobs for given deployment spec, in the order they are declared */
@@ -46,25 +44,25 @@ public class DeploymentOrder {
public List<JobStatus> sortBy(DeploymentSpec deploymentSpec, Collection<JobStatus> jobStatus) {
List<DeploymentJobs.JobType> sortedJobs = jobsFrom(deploymentSpec);
return jobStatus.stream()
- .sorted(comparingInt(job -> sortedJobs.indexOf(job.type())))
- .collect(collectingAndThen(toList(), Collections::unmodifiableList));
+ .sorted(comparingInt(job -> sortedJobs.indexOf(job.type())))
+ .collect(collectingAndThen(toList(), Collections::unmodifiableList));
}
/** Returns deployments sorted according to declared zones */
public List<Deployment> sortBy(List<DeploymentSpec.DeclaredZone> zones, Collection<Deployment> deployments) {
List<ZoneId> productionZones = zones.stream()
- .filter(z -> z.region().isPresent())
- .map(z -> ZoneId.from(z.environment(), z.region().get()))
- .collect(toList());
+ .filter(z -> z.region().isPresent())
+ .map(z -> ZoneId.from(z.environment(), z.region().get()))
+ .collect(toList());
return deployments.stream()
- .sorted(comparingInt(deployment -> productionZones.indexOf(deployment.zone())))
- .collect(collectingAndThen(toList(), Collections::unmodifiableList));
+ .sorted(comparingInt(deployment -> productionZones.indexOf(deployment.zone())))
+ .collect(collectingAndThen(toList(), Collections::unmodifiableList));
}
/** Resolve job from deployment step */
public JobType toJob(DeploymentSpec.DeclaredZone zone) {
return JobType.from(system.get(), zone.environment(), zone.region().orElse(null))
- .orElseThrow(() -> new IllegalArgumentException("Invalid zone " + zone));
+ .orElseThrow(() -> new IllegalArgumentException("Invalid zone " + zone));
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
index e902206ad8b..63a6ac234ff 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
@@ -20,7 +20,6 @@ import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobReport;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobType;
import com.yahoo.vespa.hosted.controller.application.JobStatus;
import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun;
-import com.yahoo.vespa.hosted.controller.persistence.CuratorDb;
import java.time.Clock;
import java.time.Duration;
@@ -78,14 +77,11 @@ public class DeploymentTrigger {
private final DeploymentOrder order;
private final BuildService buildService;
- public DeploymentTrigger(Controller controller, CuratorDb curator, BuildService buildService, Clock clock) {
- Objects.requireNonNull(controller, "controller cannot be null");
- Objects.requireNonNull(curator, "curator cannot be null");
- Objects.requireNonNull(clock, "clock cannot be null");
- this.controller = controller;
- this.clock = clock;
+ public DeploymentTrigger(Controller controller, BuildService buildService, Clock clock) {
+ this.controller = Objects.requireNonNull(controller, "controller cannot be null");
+ this.buildService = Objects.requireNonNull(buildService, "buildService cannot be null");
+ this.clock = Objects.requireNonNull(clock, "clock cannot be null");
this.order = new DeploymentOrder(controller::system);
- this.buildService = buildService;
}
public DeploymentOrder deploymentOrder() {
@@ -116,15 +112,15 @@ public class DeploymentTrigger {
triggering = JobRun.triggering(controller.systemVersion(), applicationVersion, Optional
.empty(), Optional.empty(), "Application commit", clock.instant());
if (report.success()) {
- if (acceptNewApplicationVersion(application))
- application = application.withChange(application.change().with(applicationVersion))
+ if (acceptNewApplicationVersion(application.get()))
+ application = application.withChange(application.get().change().with(applicationVersion))
.withOutstandingChange(Change.empty());
else
application = application.withOutstandingChange(Change.of(applicationVersion));
}
}
else {
- triggering = application.deploymentJobs().statusOf(report.jobType()).flatMap(JobStatus::lastTriggered)
+ triggering = application.get().deploymentJobs().statusOf(report.jobType()).flatMap(JobStatus::lastTriggered)
.orElseThrow(() -> new IllegalStateException("Notified of completion of " + report.jobType().jobName() + " for " +
report.applicationId() + ", but that has neither been triggered nor deployed"));
}
@@ -132,7 +128,7 @@ public class DeploymentTrigger {
report.jobType(),
triggering.completion(report.buildNumber(), clock.instant()),
report.jobError());
- application = application.withChange(remainingChange(application));
+ application = application.withChange(remainingChange(application.get()));
applications().store(application);
});
}
@@ -216,9 +212,9 @@ public class DeploymentTrigger {
*/
public void triggerChange(ApplicationId applicationId, Change change) {
applications().lockOrThrow(applicationId, application -> {
- if (application.changeAt(controller.clock().instant()).isPresent() && ! application.deploymentJobs().hasFailures())
+ if (application.get().changeAt(controller.clock().instant()).isPresent() && ! application.get().deploymentJobs().hasFailures())
throw new IllegalArgumentException("Could not start " + change + " on " + application + ": " +
- application.change() + " is already in progress");
+ application.get().change() + " is already in progress");
application = application.withChange(change);
if (change.application().isPresent())
application = application.withOutstandingChange(Change.empty());
@@ -230,7 +226,7 @@ public class DeploymentTrigger {
/** Cancels a platform upgrade of the given application, and an application upgrade as well if {@code keepApplicationChange}. */
public void cancelChange(ApplicationId applicationId, boolean keepApplicationChange) {
applications().lockOrThrow(applicationId, application -> {
- applications().store(application.withChange(application.change().application()
+ applications().store(application.withChange(application.get().change().application()
.filter(__ -> keepApplicationChange)
.map(Change::of)
.orElse(Change.empty())));
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java
index 821efba013d..4dacb2e32d6 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java
@@ -15,8 +15,8 @@ import java.util.logging.Level;
import java.util.logging.Logger;
/**
- * Retrieve deployment metrics like qps and document count from the metric service and
- * update the applications with this info.
+ * Retrieve deployment metrics such as QPS and document count from the metric service and
+ * update applications with this info.
*
* @author smorgrav
*/
@@ -39,19 +39,19 @@ public class DeploymentMetricsMaintainer extends Maintainer {
for (Deployment deployment : application.deployments().values()) {
MetricsService.DeploymentMetrics deploymentMetrics = controller().metricsService()
.getDeploymentMetrics(application.id(), deployment.zone());
- DeploymentMetrics appMetrics = new DeploymentMetrics(deploymentMetrics.queriesPerSecond(),
+ DeploymentMetrics newMetrics = new DeploymentMetrics(deploymentMetrics.queriesPerSecond(),
deploymentMetrics.writesPerSecond(),
deploymentMetrics.documentCount(),
deploymentMetrics.queryLatencyMillis(),
deploymentMetrics.writeLatencyMillis());
controller().applications().lockIfPresent(application.id(), lockedApplication ->
- controller().applications().store(lockedApplication.with(deployment.zone(), appMetrics)));
+ controller().applications().store(lockedApplication.with(deployment.zone(), newMetrics)
+ .recordActivityAt(controller().clock().instant(), deployment.zone())));
}
- }
- catch (UncheckedIOException e) {
+ } catch (UncheckedIOException e) {
if (!hasWarned) // produce only one warning per maintenance interval
- log.log(Level.WARNING, "Failed talking to YAMAS: " + Exceptions.toMessageString(e) +
+ log.log(Level.WARNING, "Failed to query metrics service: " + Exceptions.toMessageString(e) +
". Retrying in " + maintenanceInterval());
hasWarned = true;
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java
index bd8b8fc8747..22cbe942932 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java
@@ -7,7 +7,6 @@ import com.yahoo.vespa.curator.Lock;
import com.yahoo.vespa.hosted.controller.Application;
import com.yahoo.vespa.hosted.controller.Controller;
import com.yahoo.vespa.hosted.controller.application.ApplicationList;
-import com.yahoo.vespa.hosted.controller.application.Change;
import com.yahoo.vespa.hosted.controller.persistence.CuratorDb;
import com.yahoo.vespa.hosted.controller.versions.VespaVersion;
import com.yahoo.vespa.hosted.controller.versions.VespaVersion.Confidence;
@@ -18,6 +17,7 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -36,7 +36,7 @@ public class Upgrader extends Maintainer {
public Upgrader(Controller controller, Duration interval, JobControl jobControl, CuratorDb curator) {
super(controller, interval, jobControl);
- this.curator = curator;
+ this.curator = Objects.requireNonNull(curator, "curator cannot be null");
}
/**
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
index 21eea21ba68..6ad2452b2a2 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
@@ -20,6 +20,7 @@ import com.yahoo.vespa.hosted.controller.application.Change;
import com.yahoo.vespa.hosted.controller.application.ClusterInfo;
import com.yahoo.vespa.hosted.controller.application.ClusterUtilization;
import com.yahoo.vespa.hosted.controller.application.Deployment;
+import com.yahoo.vespa.hosted.controller.application.DeploymentActivity;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics;
@@ -68,6 +69,8 @@ public class ApplicationSerializer {
private final String repositoryField = "repositoryField";
private final String branchField = "branchField";
private final String commitField = "commitField";
+ private final String lastQueriedField = "lastQueried";
+ private final String lastWrittenField = "lastWritten";
// DeploymentJobs fields
private final String projectIdField = "projectId";
@@ -148,10 +151,12 @@ public class ApplicationSerializer {
toSlime(deployment.applicationVersion(), object.setObject(applicationPackageRevisionField));
clusterInfoToSlime(deployment.clusterInfo(), object);
clusterUtilsToSlime(deployment.clusterUtils(), object);
- metricsToSlime(deployment.metrics(), object);
+ deploymentMetricsToSlime(deployment.metrics(), object);
+ deployment.activity().lastQueried().ifPresent(instant -> object.setLong(lastQueriedField, instant.toEpochMilli()));
+ deployment.activity().lastWritten().ifPresent(instant -> object.setLong(lastWrittenField, instant.toEpochMilli()));
}
- private void metricsToSlime(DeploymentMetrics metrics, Cursor object) {
+ private void deploymentMetricsToSlime(DeploymentMetrics metrics, Cursor object) {
Cursor root = object.setObject(deploymentMetricsField);
root.setDouble(deploymentMetricsQPSField, metrics.queriesPerSecond());
root.setDouble(deploymentMetricsWPSField, metrics.writesPerSecond());
@@ -289,19 +294,17 @@ public class ApplicationSerializer {
Instant.ofEpochMilli(deploymentObject.field(deployTimeField).asLong()),
clusterUtilsMapFromSlime(deploymentObject.field(clusterUtilsField)),
clusterInfoMapFromSlime(deploymentObject.field(clusterInfoField)),
- deploymentMetricsFromSlime(deploymentObject.field(deploymentMetricsField)));
+ deploymentMetricsFromSlime(deploymentObject.field(deploymentMetricsField)),
+ DeploymentActivity.create(optionalInstant(deploymentObject.field(lastQueriedField)),
+ optionalInstant(deploymentObject.field(lastWrittenField))));
}
private DeploymentMetrics deploymentMetricsFromSlime(Inspector object) {
-
- double queriesPerSecond = object.field(deploymentMetricsQPSField).asDouble();
- double writesPerSecond = object.field(deploymentMetricsWPSField).asDouble();
- double documentCount = object.field(deploymentMetricsDocsField).asDouble();
- double queryLatencyMillis = object.field(deploymentMetricsQueryLatencyField).asDouble();
- double writeLatencyMills = object.field(deploymentMetricsWriteLatencyField).asDouble();
-
- return new DeploymentMetrics(queriesPerSecond, writesPerSecond,
- documentCount, queryLatencyMillis, writeLatencyMills);
+ return new DeploymentMetrics(object.field(deploymentMetricsQPSField).asDouble(),
+ object.field(deploymentMetricsWPSField).asDouble(),
+ object.field(deploymentMetricsDocsField).asDouble(),
+ object.field(deploymentMetricsQueryLatencyField).asDouble(),
+ object.field(deploymentMetricsWriteLatencyField).asDouble());
}
private Map<ClusterSpec.Id, ClusterInfo> clusterInfoMapFromSlime(Inspector object) {
@@ -426,4 +429,9 @@ public class ApplicationSerializer {
return SlimeUtils.optionalString(field);
}
+ private Optional<Instant> optionalInstant(Inspector field) {
+ OptionalLong value = optionalLong(field);
+ return value.isPresent() ? Optional.of(Instant.ofEpochMilli(value.getAsLong())) : Optional.empty();
+ }
+
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index 10088ba3fea..3eced6d943e 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -425,6 +425,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
metricsObject.setDouble("queryServiceQuality", application.metrics().queryServiceQuality());
metricsObject.setDouble("writeServiceQuality", application.metrics().writeServiceQuality());
+ // Activity
+ Cursor activity = object.setObject("activity");
+ application.activity().lastQueried().ifPresent(lastQueried -> activity.setLong("queriedAt", lastQueried.toEpochMilli()));
+ application.activity().lastWritten().ifPresent(lastQueried -> activity.setLong("writtenAt", lastQueried.toEpochMilli()));
+
application.ownershipIssueId().ifPresent(issueId -> object.setString("ownershipIssueId", issueId.value()));
application.deploymentJobs().issueId().ifPresent(issueId -> object.setString("deploymentIssueId", issueId.value()));
}
@@ -468,6 +473,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
.ifPresent(i -> response.setString("screwdriverId", String.valueOf(i)));
sourceRevisionToSlime(deployment.applicationVersion().source(), response);
+ Cursor activity = response.setObject("activity");
+ deployment.activity().lastQueried().ifPresent(instant -> activity.setLong("lastQueried",
+ instant.toEpochMilli()));
+ deployment.activity().lastWritten().ifPresent(instant -> activity.setLong("lastWritten",
+ instant.toEpochMilli()));
+
// Cost
DeploymentCost appCost = deployment.calculateCost();
Cursor costObject = response.setObject("cost");
@@ -672,7 +683,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
ApplicationId id = ApplicationId.from(tenantName, applicationName, "default");
controller.applications().lockOrThrow(id, application -> {
- controller.applications().deploymentTrigger().triggerChange(application.id(), Change.of(version));
+ controller.applications().deploymentTrigger().triggerChange(application.get().id(), Change.of(version));
});
return new MessageResponse("Triggered deployment of application '" + id + "' on version " + version);
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
index 0de153fc3f9..c24c8693688 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller;
import com.yahoo.component.Version;
import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.Environment;
@@ -179,7 +180,9 @@ public class ControllerTest {
fail("Expected exception due to illegal production deployment removal");
}
catch (IllegalArgumentException e) {
- assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml", e.getMessage());
+ assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml. " +
+ ValidationOverrides.toAllowMessage(ValidationId.deploymentRemoval),
+ e.getMessage());
}
assertNotNull("Zone was not removed",
applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(main).get()));
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
index 98189613bd0..cf2fa182d0a 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java
@@ -13,7 +13,6 @@ import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions;
import com.yahoo.vespa.hosted.controller.api.identifiers.Property;
import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId;
-import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId;
import com.yahoo.vespa.hosted.controller.api.integration.BuildService;
import com.yahoo.vespa.hosted.controller.api.integration.chef.ChefMock;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository;
@@ -41,6 +40,7 @@ import com.yahoo.vespa.hosted.rotation.config.RotationsConfig;
import java.util.Optional;
import java.util.OptionalLong;
+import java.util.function.Supplier;
import java.util.logging.Logger;
import static org.junit.Assert.assertNotNull;
@@ -64,25 +64,28 @@ public final class ControllerTester {
private final ArtifactRepositoryMock artifactRepository;
private final EntityService entityService;
private final MockBuildService buildService;
+ private final MockMetricsService metricsService;
private Controller controller;
- public ControllerTester(ManualClock clock, RotationsConfig rotationsConfig, MockCuratorDb curatorDb) {
+ public ControllerTester(ManualClock clock, RotationsConfig rotationsConfig, MockCuratorDb curatorDb,
+ MockMetricsService metricsService) {
this(new AthenzDbMock(), clock, new ConfigServerMock(new ZoneRegistryMock()),
new ZoneRegistryMock(), new GitHubMock(), curatorDb, rotationsConfig,
- new MemoryNameService(), new ArtifactRepositoryMock(), new MemoryEntityService(), new MockBuildService());
+ new MemoryNameService(), new ArtifactRepositoryMock(), new MemoryEntityService(), new MockBuildService(),
+ metricsService);
}
public ControllerTester(ManualClock clock) {
- this(clock, defaultRotationsConfig(), new MockCuratorDb());
+ this(clock, defaultRotationsConfig(), new MockCuratorDb(), new MockMetricsService());
}
public ControllerTester(RotationsConfig rotationsConfig) {
- this(new ManualClock(), rotationsConfig, new MockCuratorDb());
+ this(new ManualClock(), rotationsConfig, new MockCuratorDb(), new MockMetricsService());
}
public ControllerTester(MockCuratorDb curatorDb) {
- this(new ManualClock(), defaultRotationsConfig(), curatorDb);
+ this(new ManualClock(), defaultRotationsConfig(), curatorDb, new MockMetricsService());
}
public ControllerTester() {
@@ -93,7 +96,8 @@ public final class ControllerTester {
ConfigServerMock configServer, ZoneRegistryMock zoneRegistry,
GitHubMock gitHub, CuratorDb curator, RotationsConfig rotationsConfig,
MemoryNameService nameService, ArtifactRepositoryMock artifactRepository,
- EntityService entityService, MockBuildService buildService) {
+ EntityService entityService, MockBuildService buildService,
+ MockMetricsService metricsService) {
this.athenzDb = athenzDb;
this.clock = clock;
this.configServer = configServer;
@@ -105,8 +109,10 @@ public final class ControllerTester {
this.artifactRepository = artifactRepository;
this.entityService = entityService;
this.buildService = buildService;
+ this.metricsService = metricsService;
this.controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry,
- athenzDb, nameService, artifactRepository, entityService, buildService);
+ athenzDb, nameService, artifactRepository, entityService, buildService,
+ metricsService);
// Make root logger use time from manual clock
Logger.getLogger("").getHandlers()[0].setFilter(
@@ -138,10 +144,12 @@ public final class ControllerTester {
public MockBuildService buildService() { return buildService; }
+ public MockMetricsService metricsService() { return metricsService; }
+
/** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */
public final void createNewController() {
controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb,
- nameService, artifactRepository, entityService, buildService);
+ nameService, artifactRepository, entityService, buildService, metricsService);
}
/** Creates the given tenant and application and deploys it */
@@ -241,6 +249,10 @@ public final class ControllerTester {
new DeployOptions(false, Optional.empty(), false, deployCurrentVersion));
}
+ public Supplier<Application> application(ApplicationId application) {
+ return () -> controller().applications().require(application);
+ }
+
/** Used by ApplicationSerializerTest to avoid breaking encapsulation. Should not be used by anything else */
public static LockedApplication writable(Application application) {
return new LockedApplication(application, new Lock("/test", new MockCurator()));
@@ -251,7 +263,7 @@ public final class ControllerTester {
GitHubMock gitHub, ZoneRegistryMock zoneRegistryMock,
AthenzDbMock athensDb, MemoryNameService nameService,
ArtifactRepository artifactRepository, EntityService entityService,
- BuildService buildService) {
+ BuildService buildService, MockMetricsService metricsService) {
Controller controller = new Controller(curator,
rotationsConfig,
gitHub,
@@ -260,7 +272,7 @@ public final class ControllerTester {
new MemoryGlobalRoutingService(),
zoneRegistryMock,
configServer,
- new MockMetricsService(),
+ metricsService,
nameService,
new MockRoutingGenerator(),
new ChefMock(),
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java
index 88bbb582564..67a4139ecf1 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.controller.integration;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.vespa.hosted.controller.api.integration.MetricsService;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import java.util.HashMap;
@@ -10,16 +11,28 @@ import java.util.Map;
/**
* @author bratseth
*/
-public class MockMetricsService implements com.yahoo.vespa.hosted.controller.api.integration.MetricsService {
+public class MockMetricsService implements MetricsService {
+
+ private final Map<String, Double> metrics = new HashMap<>();
+
+ public MockMetricsService setMetric(String key, Double value) {
+ metrics.put(key, value);
+ return this;
+ }
@Override
public ApplicationMetrics getApplicationMetrics(ApplicationId application) {
- return new ApplicationMetrics(0.5, 0.7);
+ return new ApplicationMetrics(metrics.getOrDefault("queryServiceQuality", 0.5),
+ metrics.getOrDefault("writeServiceQuality", 0.7));
}
@Override
public DeploymentMetrics getDeploymentMetrics(ApplicationId application, ZoneId zone) {
- return new DeploymentMetrics(1, 2, 3, 4, 5);
+ return new DeploymentMetrics(metrics.getOrDefault("queriesPerSecond", 1D),
+ metrics.getOrDefault("writesPerSecond", 2D),
+ metrics.getOrDefault("docoumentCount", 3D).longValue(),
+ metrics.getOrDefault("queryLatencyMillis", 4D),
+ metrics.getOrDefault("writeLatencyMillis", 5D));
}
@Override
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java
index 148d11e8b38..a651210767d 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java
@@ -12,37 +12,69 @@ import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb;
import org.junit.Test;
import java.time.Duration;
+import java.time.Instant;
+import java.util.function.Supplier;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
/**
* @author smorgrav
+ * @author mpolden
*/
public class DeploymentMetricsMaintainerTest {
@Test
public void maintain() {
ControllerTester tester = new ControllerTester();
- ApplicationId app = tester.createAndDeploy("tenant1", "domain1", "app1", Environment.dev, 123).id();
+ ApplicationId appId = tester.createAndDeploy("tenant1", "domain1", "app1",
+ Environment.dev, 123).id();
+ DeploymentMetricsMaintainer maintainer = new DeploymentMetricsMaintainer(tester.controller(),
+ Duration.ofDays(1),
+ new JobControl(new MockCuratorDb()));
+ Supplier<Application> app = tester.application(appId);
+ Supplier<Deployment> deployment = () -> app.get().deployments().values().stream().findFirst().get();
- // Pre condition: no metric info on neither application nor deployment
- assertEquals(0, tester.controller().applications().require(app).metrics().queryServiceQuality(), 0);
- Deployment deployment = tester.controller().applications().get(app).get().deployments().values().stream().findAny().get();
- assertEquals(0, deployment.metrics().documentCount(), 0);
+ // No metrics gathered yet
+ assertEquals(0, app.get().metrics().queryServiceQuality(), 0);
+ assertEquals(0, deployment.get().metrics().documentCount(), 0);
+ assertFalse("Never received any queries", deployment.get().activity().lastQueried().isPresent());
+ assertFalse("Never received any writes", deployment.get().activity().lastWritten().isPresent());
- DeploymentMetricsMaintainer maintainer = new DeploymentMetricsMaintainer(tester.controller(), Duration.ofMinutes(10), new JobControl(new MockCuratorDb()));
+ // Metrics are gathered and saved to application
maintainer.maintain();
+ assertEquals(0.5, app.get().metrics().queryServiceQuality(), Double.MIN_VALUE);
+ assertEquals(0.7, app.get().metrics().writeServiceQuality(), Double.MIN_VALUE);
+ assertEquals(1, deployment.get().metrics().queriesPerSecond(), Double.MIN_VALUE);
+ assertEquals(2, deployment.get().metrics().writesPerSecond(), Double.MIN_VALUE);
+ assertEquals(3, deployment.get().metrics().documentCount(), Double.MIN_VALUE);
+ assertEquals(4, deployment.get().metrics().queryLatencyMillis(), Double.MIN_VALUE);
+ assertEquals(5, deployment.get().metrics().writeLatencyMillis(), Double.MIN_VALUE);
+ Instant t1 = tester.clock().instant();
+ assertEquals(t1, deployment.get().activity().lastQueried().get());
+ assertEquals(t1, deployment.get().activity().lastWritten().get());
- // Post condition:
- Application application = tester.controller().applications().require(app);
- assertEquals(0.5, application.metrics().queryServiceQuality(), Double.MIN_VALUE);
- assertEquals(0.7, application.metrics().writeServiceQuality(), Double.MIN_VALUE);
- deployment = application.deployments().values().stream().findAny().get();
- assertEquals(1, deployment.metrics().queriesPerSecond(), Double.MIN_VALUE);
- assertEquals(2, deployment.metrics().writesPerSecond(), Double.MIN_VALUE);
- assertEquals(3, deployment.metrics().documentCount(), Double.MIN_VALUE);
- assertEquals(4, deployment.metrics().queryLatencyMillis(), Double.MIN_VALUE);
- assertEquals(5, deployment.metrics().writeLatencyMillis(), Double.MIN_VALUE);
+ // Time passes. Activity is updated as app is still receiving traffic
+ tester.clock().advance(Duration.ofHours(1));
+ Instant t2 = tester.clock().instant();
+ maintainer.maintain();
+ assertEquals(t2, deployment.get().activity().lastQueried().get());
+ assertEquals(t2, deployment.get().activity().lastWritten().get());
+
+ // Query traffic disappears. Query activity time is no longer updated
+ tester.clock().advance(Duration.ofHours(1));
+ Instant t3 = tester.clock().instant();
+ tester.metricsService().setMetric("queriesPerSecond", 0D);
+ maintainer.maintain();
+ assertEquals(t2, deployment.get().activity().lastQueried().get());
+ assertEquals(t3, deployment.get().activity().lastWritten().get());
+
+ // Feed traffic disappears. Feed activity time is no longer updated
+ tester.clock().advance(Duration.ofHours(1));
+ tester.metricsService().setMetric("writesPerSecond", 0D);
+ maintainer.maintain();
+ assertEquals(t2, deployment.get().activity().lastQueried().get());
+ assertEquals(t3, deployment.get().activity().lastWritten().get());
}
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
index f6bf3bdd8cf..5c5827fa167 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
@@ -9,7 +9,6 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.slime.Slime;
import com.yahoo.vespa.config.SlimeUtils;
import com.yahoo.vespa.hosted.controller.Application;
-import com.yahoo.vespa.hosted.controller.ControllerTester;
import com.yahoo.vespa.hosted.controller.api.integration.MetricsService;
import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId;
import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
@@ -18,6 +17,7 @@ import com.yahoo.vespa.hosted.controller.application.Change;
import com.yahoo.vespa.hosted.controller.application.ClusterInfo;
import com.yahoo.vespa.hosted.controller.application.ClusterUtilization;
import com.yahoo.vespa.hosted.controller.application.Deployment;
+import com.yahoo.vespa.hosted.controller.application.DeploymentActivity;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs;
import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics;
@@ -42,7 +42,6 @@ import static com.yahoo.config.provision.SystemName.main;
import static com.yahoo.vespa.hosted.controller.ControllerTester.writable;
import static java.util.Optional.empty;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
/**
* @author bratseth
@@ -56,7 +55,6 @@ public class ApplicationSerializerTest {
@Test
public void testSerialization() {
- ControllerTester tester = new ControllerTester();
DeploymentSpec deploymentSpec = DeploymentSpec.fromXml("<deployment version='1.0'>" +
" <staging/>" +
"</deployment>");
@@ -68,9 +66,12 @@ public class ApplicationSerializerTest {
ApplicationVersion applicationVersion1 = ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 31);
ApplicationVersion applicationVersion2 = ApplicationVersion
.from(new SourceRevision("repo1", "branch1", "commit1"), 32);
+ Instant activityAt = Instant.parse("2018-06-01T10:15:30.00Z");
deployments.add(new Deployment(zone1, applicationVersion1, Version.fromString("1.2.3"), Instant.ofEpochMilli(3))); // One deployment without cluster info and utils
deployments.add(new Deployment(zone2, applicationVersion2, Version.fromString("1.2.3"), Instant.ofEpochMilli(5),
- createClusterUtils(3, 0.2), createClusterInfo(3, 4),new DeploymentMetrics(2,3,4,5,6)));
+ createClusterUtils(3, 0.2), createClusterInfo(3, 4),
+ new DeploymentMetrics(2,3,4,5,6),
+ DeploymentActivity.create(Optional.of(activityAt), Optional.of(activityAt))));
OptionalLong projectId = OptionalLong.of(123L);
List<JobStatus> statusList = new ArrayList<>();
@@ -111,6 +112,8 @@ public class ApplicationSerializerTest {
assertEquals(original.deployments().get(zone2).version(), serialized.deployments().get(zone2).version());
assertEquals(original.deployments().get(zone1).at(), serialized.deployments().get(zone1).at());
assertEquals(original.deployments().get(zone2).at(), serialized.deployments().get(zone2).at());
+ assertEquals(original.deployments().get(zone2).activity().lastQueried().get(), serialized.deployments().get(zone2).activity().lastQueried().get());
+ assertEquals(original.deployments().get(zone2).activity().lastWritten().get(), serialized.deployments().get(zone2).activity().lastWritten().get());
assertEquals(original.deploymentJobs().projectId(), serialized.deploymentJobs().projectId());
assertEquals(original.deploymentJobs().jobStatus().size(), serialized.deploymentJobs().jobStatus().size());
@@ -146,34 +149,33 @@ public class ApplicationSerializerTest {
// Test metrics
assertEquals(original.metrics().queryServiceQuality(), serialized.metrics().queryServiceQuality(), Double.MIN_VALUE);
assertEquals(original.metrics().writeServiceQuality(), serialized.metrics().writeServiceQuality(), Double.MIN_VALUE);
-
- assertEquals(2, serialized.deployments().get(zone2).metrics().queriesPerSecond(), Double.MIN_VALUE);
- assertEquals(3, serialized.deployments().get(zone2).metrics().writesPerSecond(), Double.MIN_VALUE);
- assertEquals(4, serialized.deployments().get(zone2).metrics().documentCount(), Double.MIN_VALUE);
- assertEquals(5, serialized.deployments().get(zone2).metrics().queryLatencyMillis(), Double.MIN_VALUE);
- assertEquals(6, serialized.deployments().get(zone2).metrics().writeLatencyMillis(), Double.MIN_VALUE);
+ assertEquals(original.deployments().get(zone2).metrics().queriesPerSecond(), serialized.deployments().get(zone2).metrics().queriesPerSecond(), Double.MIN_VALUE);
+ assertEquals(original.deployments().get(zone2).metrics().writesPerSecond(), serialized.deployments().get(zone2).metrics().writesPerSecond(), Double.MIN_VALUE);
+ assertEquals(original.deployments().get(zone2).metrics().documentCount(), serialized.deployments().get(zone2).metrics().documentCount(), Double.MIN_VALUE);
+ assertEquals(original.deployments().get(zone2).metrics().queryLatencyMillis(), serialized.deployments().get(zone2).metrics().queryLatencyMillis(), Double.MIN_VALUE);
+ assertEquals(original.deployments().get(zone2).metrics().writeLatencyMillis(), serialized.deployments().get(zone2).metrics().writeLatencyMillis(), Double.MIN_VALUE);
{ // test more deployment serialization cases
- Application original2 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 42)));
+ Application original2 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 42))).get();
Application serialized2 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original2));
assertEquals(original2.change(), serialized2.change());
assertEquals(serialized2.change().application().get().source(),
original2.change().application().get().source());
- Application original3 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42)));
+ Application original3 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get();
Application serialized3 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original3));
assertEquals(original3.change(), serialized3.change());
assertEquals(serialized3.change().application().get().source(),
original3.change().application().get().source());
- Application original4 = writable(original).withChange(Change.empty());
+ Application original4 = writable(original).withChange(Change.empty()).get();
Application serialized4 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original4));
assertEquals(original4.change(), serialized4.change());
- Application original5 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42)));
+ Application original5 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get();
Application serialized5 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original5));
assertEquals(original5.change(), serialized5.change());
- Application original6 = writable(original).withOutstandingChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42)));
+ Application original6 = writable(original).withOutstandingChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get();
Application serialized6 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original6));
assertEquals(original6.outstandingChange(), serialized6.outstandingChange());
}
@@ -210,15 +212,6 @@ public class ApplicationSerializerTest {
}
@Test
- public void testLegacySerialization() {
- Application applicationWithSuccessfulJob = applicationSerializer.fromSlime(applicationSlime(false));
- assertFalse("No job error for successful job", applicationWithSuccessfulJob.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).jobError().isPresent());
-
- Application applicationWithFailingJob = applicationSerializer.fromSlime(applicationSlime(true));
- assertEquals(JobError.unknown, applicationWithFailingJob.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).jobError().get());
- }
-
- @Test
public void testCompleteApplicationDeserialization() throws Exception {
byte[] applicationJson = Files.readAllBytes(testData.resolve("complete-application.json"));
applicationSerializer.fromSlime(SlimeUtils.jsonToSlime(applicationJson));
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
index 8d734ec549c..545ee529635 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
@@ -57,6 +57,7 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -1123,7 +1124,8 @@ public class ApplicationApiTest extends ControllerContainerTest {
lockedApplication = lockedApplication
.withClusterInfo(deployment.zone(), clusterInfo)
.withClusterUtilization(deployment.zone(), clusterUtils)
- .with(deployment.zone(), metrics);
+ .with(deployment.zone(), metrics)
+ .recordActivityAt(Instant.parse("2018-06-01T10:15:30.00Z"), deployment.zone());
}
controllerTester.controller().applications().store(lockedApplication);
});
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
index 30070e509c7..f8c4c26d6a8 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json
@@ -249,5 +249,9 @@
"metrics": {
"queryServiceQuality": 0.5,
"writeServiceQuality": 0.7
+ },
+ "activity": {
+ "queriedAt": 1527848130000,
+ "writtenAt": 1527848130000
}
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
index dabeb3239aa..fc0f83c2cdc 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json
@@ -234,5 +234,9 @@
"metrics": {
"queryServiceQuality": 0.5,
"writeServiceQuality": 0.7
+ },
+ "activity": {
+ "queriedAt": 1527848130000,
+ "writtenAt": 1527848130000
}
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
index 174bb2f1ba7..8bb1ee83282 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json
@@ -222,6 +222,10 @@
"queryServiceQuality": 0.5,
"writeServiceQuality": 0.7
},
+ "activity": {
+ "queriedAt": 1527848130000,
+ "writtenAt": 1527848130000
+ },
"ownershipIssueId": "321",
"deploymentIssueId": "123"
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
index 9174e7dd8b2..79e86b5f7f4 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json
@@ -16,6 +16,10 @@
"gitRepository": "repository1",
"gitBranch": "master",
"gitCommit": "commit1",
+ "activity": {
+ "lastQueried": 1527848130000,
+ "lastWritten": 1527848130000
+ },
"cost": {
"tco": 74,
"waste": 0,
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
index d13a4dac116..8fccd738554 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json
@@ -15,8 +15,10 @@
"revision": "(ignore)",
"deployTimeEpochMs": "(ignore)",
"screwdriverId": "123",
-
-
+ "activity": {
+ "lastQueried": 1527848130000,
+ "lastWritten": 1527848130000
+ },
"cost": {
"tco": 74,
"waste": 0,
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json
index 0f16bee308d..066e840fe16 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json
@@ -22,6 +22,10 @@
"gitRepository": "repository1",
"gitBranch": "master",
"gitCommit": "commit1",
+ "activity": {
+ "lastQueried": 1527848130000,
+ "lastWritten": 1527848130000
+ },
"cost": {
"tco": 74,
"waste": 0,
diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java
index 5b1a4412b41..419b60432c4 100644
--- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java
+++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java
@@ -1,11 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.docprocs.indexing;
+import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.Document;
import com.yahoo.document.DocumentType;
import com.yahoo.document.DocumentUpdate;
import com.yahoo.document.Field;
+import com.yahoo.document.MapDataType;
import com.yahoo.document.StructDataType;
import com.yahoo.document.annotation.SpanTree;
import com.yahoo.document.annotation.SpanTrees;
@@ -16,6 +18,7 @@ import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.Struct;
import com.yahoo.document.datatypes.WeightedSet;
import com.yahoo.document.fieldpathupdate.AssignFieldPathUpdate;
+import com.yahoo.document.fieldpathupdate.FieldPathUpdate;
import com.yahoo.document.update.FieldUpdate;
import com.yahoo.document.update.MapValueUpdate;
import com.yahoo.document.update.ValueUpdate;
@@ -30,6 +33,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -157,6 +161,82 @@ public class DocumentScriptTestCase {
assertSpanTrees(str, "mySpanTree");
}
+ private class FieldPathFixture {
+ final DocumentType type;
+ final StructDataType structType;
+ final DataType structMap;
+ final DataType structArray;
+
+ FieldPathFixture() {
+ type = newDocumentType();
+ structType = new StructDataType("mystruct");
+ structType.addField(new Field("title", DataType.STRING));
+ structType.addField(new Field("rating", DataType.INT));
+ structArray = new ArrayDataType(structType);
+ type.addField(new Field("structarray", structArray));
+ structMap = new MapDataType(DataType.STRING, structType);
+ type.addField(new Field("structmap", structMap));
+ type.addField(new Field("structfield", structType));
+ }
+
+ DocumentUpdate executeWithUpdate(String fieldName, FieldPathUpdate updateIn) {
+ DocumentUpdate update = new DocumentUpdate(type, "doc:scheme:");
+ update.addFieldPathUpdate(updateIn);
+ return newScript(type, fieldName).execute(ADAPTER_FACTORY, update);
+ }
+
+ FieldPathUpdate executeWithUpdateAndExpectFieldPath(String fieldName, FieldPathUpdate updateIn) {
+ DocumentUpdate update = executeWithUpdate(fieldName, updateIn);
+ assertEquals(1, update.getFieldPathUpdates().size());
+ return update.getFieldPathUpdates().get(0);
+ }
+ }
+
+ @Test
+ public void array_field_path_updates_survive_indexing_scripts() {
+ FieldPathFixture f = new FieldPathFixture();
+
+ Struct newElemValue = new Struct(f.structType);
+ newElemValue.setFieldValue("title", "iron moose 2, the moosening");
+
+ FieldPathUpdate updated = f.executeWithUpdateAndExpectFieldPath("structarray", new AssignFieldPathUpdate(f.type, "structarray[10]", newElemValue));
+
+ assertTrue(updated instanceof AssignFieldPathUpdate);
+ AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)updated;
+ assertEquals("structarray[10]", assignUpdate.getOriginalFieldPath());
+ assertEquals(newElemValue, assignUpdate.getFieldValue());
+ }
+
+ @Test
+ public void map_field_path_updates_survive_indexing_scripts() {
+ FieldPathFixture f = new FieldPathFixture();
+
+ Struct newElemValue = new Struct(f.structType);
+ newElemValue.setFieldValue("title", "iron moose 3, moose in new york");
+
+ FieldPathUpdate updated = f.executeWithUpdateAndExpectFieldPath("structmap", new AssignFieldPathUpdate(f.type, "structmap{foo}", newElemValue));
+
+ assertTrue(updated instanceof AssignFieldPathUpdate);
+ AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)updated;
+ assertEquals("structmap{foo}", assignUpdate.getOriginalFieldPath());
+ assertEquals(newElemValue, assignUpdate.getFieldValue());
+ }
+
+ @Test
+ public void nested_struct_fieldpath_update_is_not_converted_to_regular_field_value_update() {
+ FieldPathFixture f = new FieldPathFixture();
+
+ StringFieldValue newTitleValue = new StringFieldValue("iron moose 4, moose with a vengeance");
+ DocumentUpdate update = f.executeWithUpdate("structfield", new AssignFieldPathUpdate(f.type, "structfield.title", newTitleValue));
+
+ assertEquals(1, update.getFieldPathUpdates().size());
+ assertEquals(0, update.getFieldUpdates().size());
+ assertTrue(update.getFieldPathUpdates().get(0) instanceof AssignFieldPathUpdate);
+ AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)update.getFieldPathUpdates().get(0);
+ assertEquals("structfield.title", assignUpdate.getOriginalFieldPath());
+ assertEquals(newTitleValue, assignUpdate.getFieldValue());
+ }
+
private static FieldValue processDocument(FieldValue fieldValue) {
DocumentType docType = new DocumentType("myDocumentType");
docType.addField("myField", fieldValue.getDataType());
@@ -184,11 +264,15 @@ public class DocumentScriptTestCase {
return update.getFieldUpdate("myField").getValueUpdate(0);
}
+ private static DocumentScript newScript(DocumentType docType, String fieldName) {
+ return new DocumentScript(docType.getName(), Collections.singletonList(fieldName),
+ new StatementExpression(new InputExpression(fieldName),
+ new IndexExpression(fieldName)));
+ }
+
private static DocumentScript newScript(DocumentType docType) {
String fieldName = docType.getFields().iterator().next().getName();
- return new DocumentScript(docType.getName(), Arrays.asList(fieldName),
- new StatementExpression(new InputExpression(fieldName),
- new IndexExpression(fieldName)));
+ return newScript(docType, fieldName);
}
private static StringFieldValue newString(String... spanTrees) {
@@ -210,6 +294,7 @@ public class DocumentScriptTestCase {
DocumentType type = new DocumentType("documentType");
type.addField("documentField", DataType.STRING);
type.addField("extraField", DataType.STRING);
+
return type;
}
diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java
index e37a32f28f4..01326bcea62 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/Array.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java
@@ -290,7 +290,8 @@ public final class Array<T extends FieldValue> extends CollectionFieldValue<T> i
if (pos < fieldPath.size()) {
switch (fieldPath.get(pos).getType()) {
case ARRAY_INDEX:
- return iterateSubset(fieldPath.get(pos).getLookupIndex(), fieldPath.get(pos).getLookupIndex(), fieldPath, null, pos + 1, handler);
+ final int elemIndex = fieldPath.get(pos).getLookupIndex();
+ return iterateSubset(elemIndex, elemIndex, fieldPath, null, pos + 1, handler);
case VARIABLE: {
FieldPathIteratorHandler.IndexValue val = handler.getVariables().get(fieldPath.get(pos).getVariableName());
if (val != null) {
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index 1415ca6e5aa..0011d108b98 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -16,13 +16,6 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>config-model</artifactId>
<version>${project.version}</version>
- <exclusions>
- <exclusion>
- <!-- Large, and installed separately as part of Vespa -->
- <groupId>org.tensorflow</groupId>
- <artifactId>libtensorflow_jni</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java
index 171c6a8eb9a..5c170fe147e 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java
@@ -20,19 +20,10 @@ public abstract class FieldPathUpdateHelper {
if (!(update instanceof AssignFieldPathUpdate)) {
return false;
}
- for (FieldPathEntry entry : update.getFieldPath()) {
- switch (entry.getType()) {
- case STRUCT_FIELD:
- case MAP_ALL_KEYS:
- case MAP_ALL_VALUES:
- continue;
- case ARRAY_INDEX:
- case MAP_KEY:
- case VARIABLE:
- return false;
- }
- }
- return true;
+ // Only consider field path updates that touch a top-level field as 'complete',
+ // as these may be converted to regular field value updates.
+ return ((update.getFieldPath().size() == 1)
+ && update.getFieldPath().get(0).getType() == FieldPathEntry.Type.STRUCT_FIELD);
}
public static void applyUpdate(FieldPathUpdate update, Document doc) {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java
new file mode 100644
index 00000000000..42c9bd8c10c
--- /dev/null
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java
@@ -0,0 +1,68 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.indexinglanguage;
+
+import com.yahoo.document.DataType;
+import com.yahoo.document.Document;
+import com.yahoo.document.DocumentUpdate;
+import com.yahoo.document.FieldPath;
+import com.yahoo.document.datatypes.FieldValue;
+import com.yahoo.document.fieldpathupdate.FieldPathUpdate;
+import com.yahoo.vespa.indexinglanguage.expressions.Expression;
+import com.yahoo.vespa.indexinglanguage.expressions.FieldValueAdapter;
+
+/**
+ * No-op update adapter which simply passes through the input update unchanged.
+ * I.e. getOutput() will return a DocumentUpdate containing only the FieldPathUpdate
+ * the IdentityFieldPathUpdateAdapter was created with. All other applicable calls are
+ * forwarded to the provided DocumentAdapter instance.
+ *
+ * This removes the need for a potentially lossy round-trip of update -&gt; synthetic document -&gt; update.
+ */
+public class IdentityFieldPathUpdateAdapter implements UpdateAdapter {
+
+ private final FieldPathUpdate update;
+ private final DocumentAdapter fwdAdapter;
+
+ public IdentityFieldPathUpdateAdapter(FieldPathUpdate update, DocumentAdapter fwdAdapter) {
+ this.update = update;
+ this.fwdAdapter = fwdAdapter;
+ }
+
+ @Override
+ public DocumentUpdate getOutput() {
+ Document doc = fwdAdapter.getFullOutput();
+ DocumentUpdate upd = new DocumentUpdate(doc.getDataType(), doc.getId());
+ upd.addFieldPathUpdate(update);
+ return upd;
+ }
+
+ @Override
+ public Expression getExpression(Expression expression) {
+ return expression;
+ }
+
+ @Override
+ public FieldValue getInputValue(String fieldName) {
+ return fwdAdapter.getInputValue(fieldName);
+ }
+
+ @Override
+ public FieldValue getInputValue(FieldPath fieldPath) {
+ return fwdAdapter.getInputValue(fieldPath);
+ }
+
+ @Override
+ public FieldValueAdapter setOutputValue(Expression exp, String fieldName, FieldValue fieldValue) {
+ return fwdAdapter.setOutputValue(exp, fieldName, fieldValue);
+ }
+
+ @Override
+ public DataType getInputType(Expression exp, String fieldName) {
+ return fwdAdapter.getInputType(exp, fieldName);
+ }
+
+ @Override
+ public void tryOutputType(Expression exp, String fieldName, DataType valueType) {
+ fwdAdapter.tryOutputType(exp, fieldName, valueType);
+ }
+}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java
index 2ad09dfbdc4..509bdcaa32d 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java
@@ -49,10 +49,12 @@ public class SimpleAdapterFactory implements AdapterFactory {
Document complete = new Document(docType, upd.getId());
for (FieldPathUpdate fieldUpd : upd) {
if (FieldPathUpdateHelper.isComplete(fieldUpd)) {
+ // A 'complete' field path update is basically a regular top-level field update
+ // in wolf's clothing. Convert it to a regular field update to be friendlier
+ // towards the search core backend.
FieldPathUpdateHelper.applyUpdate(fieldUpd, complete);
} else {
- Document partial = FieldPathUpdateHelper.newPartialDocument(docId, fieldUpd);
- ret.add(new FieldPathUpdateAdapter(newDocumentAdapter(partial, true), fieldUpd));
+ ret.add(new IdentityFieldPathUpdateAdapter(fieldUpd, newDocumentAdapter(complete, true)));
}
}
for (FieldUpdate fieldUpd : upd.getFieldUpdates()) {
diff --git a/jdisc_http_service/pom.xml b/jdisc_http_service/pom.xml
index 6373189e738..f41994c4916 100644
--- a/jdisc_http_service/pom.xml
+++ b/jdisc_http_service/pom.xml
@@ -175,7 +175,6 @@
<extensions>true</extensions>
<configuration>
<discPreInstallBundle>
- asm-debug-all-${asm-debug-all.version}.jar,
bcpkix-jdk15on-${bouncycastle.version}.jar,
bcprov-jdk15on-${bouncycastle.version}.jar,
javax.servlet-api-3.1.0.jar,
@@ -188,8 +187,6 @@
jetty-servlet-${jetty.version}.jar,
jetty-servlets-${jetty.version}.jar,
jetty-util-${jetty.version}.jar,
- org.apache.aries.spifly.dynamic.bundle-${aries.spifly.version}.jar,
- org.apache.aries.util-${aries.util.version}.jar,
component-jar-with-dependencies.jar
</discPreInstallBundle>
</configuration>
diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java
index 80c1cb8b458..77411fc080e 100644
--- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java
+++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java
@@ -323,7 +323,7 @@ public class HttpServerConformanceTest extends ServerProviderConformanceTest {
@Override
@Test
public void testRequestContentWriteExceptionAfterResponseWriteWithSyncCompletion() throws Throwable {
- new TestRunner().expect(success())
+ new TestRunner().expect(anyOf(success(), successNoContent()))
.execute();
}
diff --git a/jdisc_jetty/pom.xml b/jdisc_jetty/pom.xml
index 0f8a5ba19e2..404476f7bf2 100644
--- a/jdisc_jetty/pom.xml
+++ b/jdisc_jetty/pom.xml
@@ -16,10 +16,6 @@
<packaging>jar</packaging>
<dependencies>
<dependency>
- <groupId>org.apache.aries.spifly</groupId>
- <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId>
- </dependency>
- <dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-continuation</artifactId>
</dependency>
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java
index 7f2d1f1eff7..a7bf22591d4 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java
@@ -276,8 +276,8 @@ public class StorageMaintainer {
*/
public void handleCoreDumpsForContainer(ContainerName containerName, NodeSpec node, boolean force) {
// Sample number of coredumps on the host
- try {
- numberOfCoredumpsOnHost.sample(Files.list(environment.pathInNodeAdminToDoneCoredumps()).count());
+ try (Stream<Path> files = Files.list(environment.pathInNodeAdminToDoneCoredumps())) {
+ numberOfCoredumpsOnHost.sample(files.count());
} catch (IOException e) {
// Ignore for now - this is either test or a misconfiguration
}
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
index f7e9c3ca1d8..ff85c49bb13 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
@@ -1,6 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.node.admin.maintenance.identity;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.yahoo.vespa.athenz.api.AthenzService;
import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient;
import com.yahoo.vespa.athenz.client.zts.InstanceIdentity;
@@ -9,7 +11,7 @@ import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
-import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
+import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity;
import com.yahoo.vespa.athenz.identityprovider.client.DefaultIdentityDocumentClient;
import com.yahoo.vespa.athenz.identityprovider.client.InstanceCsrGenerator;
import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier;
@@ -19,9 +21,9 @@ import com.yahoo.vespa.athenz.tls.KeyUtils;
import com.yahoo.vespa.athenz.tls.Pkcs10Csr;
import com.yahoo.vespa.athenz.tls.SslContextBuilder;
import com.yahoo.vespa.athenz.tls.X509CertificateUtils;
+import com.yahoo.vespa.athenz.utils.SiaUtils;
import com.yahoo.vespa.hosted.dockerapi.ContainerName;
import com.yahoo.vespa.hosted.node.admin.component.Environment;
-import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec;
import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger;
import javax.net.ssl.SSLContext;
@@ -38,7 +40,6 @@ import java.security.cert.X509Certificate;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
-import java.util.Set;
import static java.util.Collections.singleton;
@@ -53,12 +54,15 @@ public class AthenzCredentialsMaintainer {
private static final Duration REFRESH_PERIOD = Duration.ofDays(1);
private static final Path CONTAINER_SIA_DIRECTORY = Paths.get("/var/lib/sia");
+ private static final ObjectMapper mapper = new ObjectMapper().registerModule(new JavaTimeModule());
+
private final boolean enabled;
private final PrefixLogger log;
private final String hostname;
private final Path trustStorePath;
private final Path privateKeyFile;
private final Path certificateFile;
+ private final Path identityDocumentFile;
private final AthenzService containerIdentity;
private final URI ztsEndpoint;
private final Clock clock;
@@ -66,8 +70,6 @@ public class AthenzCredentialsMaintainer {
private final IdentityDocumentClient identityDocumentClient;
private final InstanceCsrGenerator csrGenerator;
private final AthenzService configserverIdentity;
- private final String zoneRegion;
- private final String zoneEnvironment;
public AthenzCredentialsMaintainer(String hostname,
Environment environment,
@@ -82,8 +84,9 @@ public class AthenzCredentialsMaintainer {
this.configserverIdentity = environment.getConfigserverAthenzIdentity();
this.csrGenerator = new InstanceCsrGenerator(environment.getCertificateDnsSuffix());
this.trustStorePath = environment.getTrustStorePath();
- this.privateKeyFile = getPrivateKeyFile(containerSiaDirectory, containerIdentity);
- this.certificateFile = getCertificateFile(containerSiaDirectory, containerIdentity);
+ this.privateKeyFile = SiaUtils.getPrivateKeyFile(containerSiaDirectory, containerIdentity);
+ this.certificateFile = SiaUtils.getCertificateFile(containerSiaDirectory, containerIdentity);
+ this.identityDocumentFile = containerSiaDirectory.resolve("vespa-node-identity-document.json");
this.hostIdentityProvider = hostIdentityProvider;
this.identityDocumentClient =
new DefaultIdentityDocumentClient(
@@ -91,15 +94,12 @@ public class AthenzCredentialsMaintainer {
hostIdentityProvider,
new AthenzIdentityVerifier(singleton(configserverIdentity)));
this.clock = Clock.systemUTC();
- this.zoneRegion = environment.getRegion();
- this.zoneEnvironment = environment.getEnvironment();
}
/**
- * @param nodeSpec Node specification
* @return Returns true if credentials were updated
*/
- public boolean converge(NodeSpec nodeSpec) {
+ public boolean converge() {
try {
if (!enabled) {
log.debug("Feature disabled on this host - not fetching certificate");
@@ -107,26 +107,25 @@ public class AthenzCredentialsMaintainer {
}
log.debug("Checking certificate");
Instant now = clock.instant();
- VespaUniqueInstanceId instanceId = getVespaUniqueInstanceId(nodeSpec);
- Set<String> ipAddresses = nodeSpec.getIpAddresses();
- if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile)) {
- log.info("Certificate and/or private key file does not exist");
+ if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile) || !Files.exists(identityDocumentFile)) {
+ log.info("Certificate/private key/identity document file does not exist");
Files.createDirectories(privateKeyFile.getParent());
Files.createDirectories(certificateFile.getParent());
- registerIdentity(instanceId, ipAddresses);
+ Files.createDirectories(identityDocumentFile.getParent());
+ registerIdentity();
return true;
}
X509Certificate certificate = readCertificateFromFile();
Instant expiry = certificate.getNotAfter().toInstant();
if (isCertificateExpired(expiry, now)) {
log.info(String.format("Certificate has expired (expiry=%s)", expiry.toString()));
- registerIdentity(instanceId, ipAddresses);
+ registerIdentity();
return true;
}
Duration age = Duration.between(certificate.getNotBefore().toInstant(), now);
if (shouldRefreshCredentials(age)) {
log.info(String.format("Certificate is ready to be refreshed (age=%s)", age.toString()));
- refreshIdentity(instanceId, ipAddresses);
+ refreshIdentity();
return true;
}
log.debug("Certificate is still valid");
@@ -148,19 +147,6 @@ public class AthenzCredentialsMaintainer {
}
}
- private VespaUniqueInstanceId getVespaUniqueInstanceId(NodeSpec nodeSpec) {
- NodeSpec.Membership membership = nodeSpec.getMembership().get();
- NodeSpec.Owner owner = nodeSpec.getOwner().get();
- return new VespaUniqueInstanceId(
- membership.getIndex(),
- membership.getClusterId(),
- owner.getInstance(),
- owner.getApplication(),
- owner.getTenant(),
- zoneRegion,
- zoneEnvironment);
- }
-
private boolean shouldRefreshCredentials(Duration age) {
return age.compareTo(REFRESH_PERIOD) >= 0;
}
@@ -174,32 +160,32 @@ public class AthenzCredentialsMaintainer {
return now.isAfter(expiry.minus(EXPIRY_MARGIN));
}
- private void registerIdentity(VespaUniqueInstanceId instanceId, Set<String> ipAddresses) {
+ private void registerIdentity() {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
- Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, instanceId, ipAddresses, keyPair);
SignedIdentityDocument signedIdentityDocument = identityDocumentClient.getNodeIdentityDocument(hostname);
+ Pkcs10Csr csr = csrGenerator.generateCsr(
+ containerIdentity, signedIdentityDocument.providerUniqueId(), signedIdentityDocument.ipAddresses(), keyPair);
try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, hostIdentityProvider)) {
InstanceIdentity instanceIdentity =
ztsClient.registerInstance(
configserverIdentity,
containerIdentity,
- instanceId.asDottedString(),
+ signedIdentityDocument.providerUniqueId().asDottedString(),
EntityBindingsMapper.toAttestationData(signedIdentityDocument),
false,
csr);
+ writeIdentityDocument(signedIdentityDocument);
writePrivateKeyAndCertificate(keyPair.getPrivate(), instanceIdentity.certificate());
log.info("Instance successfully registered and credentials written to file");
} catch (IOException e) {
throw new UncheckedIOException(e);
- } catch (Exception e) {
- // TODO Change close() in ZtsClient to not throw checked exception
- throw new RuntimeException(e);
}
}
- private void refreshIdentity(VespaUniqueInstanceId instanceId, Set<String> ipAddresses) {
+ private void refreshIdentity() {
+ SignedIdentityDocument identityDocument = readIdentityDocument();
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
- Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, instanceId, ipAddresses, keyPair);
+ Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair);
SSLContext containerIdentitySslContext =
new SslContextBuilder()
.withKeyStore(privateKeyFile.toFile(), certificateFile.toFile())
@@ -210,16 +196,34 @@ public class AthenzCredentialsMaintainer {
ztsClient.refreshInstance(
configserverIdentity,
containerIdentity,
- instanceId.asDottedString(),
+ identityDocument.providerUniqueId().asDottedString(),
false,
csr);
writePrivateKeyAndCertificate(keyPair.getPrivate(), instanceIdentity.certificate());
log.info("Instance successfully refreshed and credentials written to file");
} catch (IOException e) {
throw new UncheckedIOException(e);
- } catch (Exception e) {
- // TODO Change close() in ZtsClient to not throw checked exception
- throw new RuntimeException(e);
+ }
+ }
+
+ private SignedIdentityDocument readIdentityDocument() {
+ try {
+ SignedIdentityDocumentEntity entity = mapper.readValue(identityDocumentFile.toFile(), SignedIdentityDocumentEntity.class);
+ return EntityBindingsMapper.toSignedIdentityDocument(entity);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private void writeIdentityDocument(SignedIdentityDocument signedIdentityDocument) {
+ try {
+ SignedIdentityDocumentEntity entity =
+ EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument);
+ Path tempIdentityDocumentFile = toTempPath(identityDocumentFile);
+ mapper.writeValue(tempIdentityDocumentFile.toFile(), entity);
+ Files.move(tempIdentityDocumentFile, identityDocumentFile, StandardCopyOption.ATOMIC_MOVE);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
}
}
@@ -237,18 +241,4 @@ public class AthenzCredentialsMaintainer {
return Paths.get(file.toAbsolutePath().toString() + ".tmp");
}
- // TODO Move to vespa-athenz
- private static Path getPrivateKeyFile(Path root, AthenzService service) {
- return root
- .resolve("keys")
- .resolve(String.format("%s.%s.key.pem", service.getDomain().getName(), service.getName()));
- }
-
- // TODO Move to vespa-athenz
- private static Path getCertificateFile(Path root, AthenzService service) {
- return root
- .resolve("certs")
- .resolve(String.format("%s.%s.cert.pem", service.getDomain().getName(), service.getName()));
- }
-
}
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
index 7fa9a90b744..5f1b7aefcfe 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
@@ -498,7 +498,7 @@ public class NodeAgentImpl implements NodeAgent {
runLocalResumeScriptIfNeeded(node);
- athenzCredentialsMaintainer.converge(node);
+ athenzCredentialsMaintainer.converge();
doBeforeConverge(node);
diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java
index 99dfdb48334..63c74c17dd5 100644
--- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java
+++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java
@@ -72,7 +72,7 @@ class CoredumpHandler {
FileHelper.deleteDirectories(doneCoredumpsPath, Duration.ofDays(10), Optional.empty());
}
- private void handleNewCoredumps() throws IOException {
+ private void handleNewCoredumps() {
Path processingCoredumps = enqueueCoredumps();
processAndReportCoredumps(processingCoredumps);
}
@@ -82,12 +82,12 @@ class CoredumpHandler {
* Moves a coredump to a new directory under the processing/ directory. Limit to only processing
* one coredump at the time, starting with the oldest.
*/
- Path enqueueCoredumps() throws IOException {
+ Path enqueueCoredumps() {
Path processingCoredumpsPath = coredumpsPath.resolve(PROCESSING_DIRECTORY_NAME);
processingCoredumpsPath.toFile().mkdirs();
- if (Files.list(processingCoredumpsPath).count() > 0) return processingCoredumpsPath;
+ if (!FileHelper.listContentsOfDirectory(processingCoredumpsPath).isEmpty()) return processingCoredumpsPath;
- Files.list(coredumpsPath)
+ FileHelper.listContentsOfDirectory(coredumpsPath).stream()
.filter(path -> path.toFile().isFile() && ! path.getFileName().toString().startsWith("."))
.min((Comparator.comparingLong(o -> o.toFile().lastModified())))
.ifPresent(coredumpPath -> {
@@ -101,10 +101,10 @@ class CoredumpHandler {
return processingCoredumpsPath;
}
- void processAndReportCoredumps(Path processingCoredumpsPath) throws IOException {
+ void processAndReportCoredumps(Path processingCoredumpsPath) {
doneCoredumpsPath.toFile().mkdirs();
- Files.list(processingCoredumpsPath)
+ FileHelper.listContentsOfDirectory(processingCoredumpsPath).stream()
.filter(path -> path.toFile().isDirectory())
.forEach(coredumpDirectory -> {
try {
@@ -130,7 +130,7 @@ class CoredumpHandler {
String collectMetadata(Path coredumpDirectory, Map<String, Object> nodeAttributes) throws IOException {
Path metadataPath = coredumpDirectory.resolve(METADATA_FILE_NAME);
if (!Files.exists(metadataPath)) {
- Path coredumpPath = Files.list(coredumpDirectory).findFirst()
+ Path coredumpPath = FileHelper.listContentsOfDirectory(coredumpDirectory).stream().findFirst()
.orElseThrow(() -> new RuntimeException("No coredump file found in processing directory " + coredumpDirectory));
Map<String, Object> metadata = coreCollector.collect(coredumpPath, installStatePath);
metadata.putAll(nodeAttributes);
diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java
index ae872042853..7b93e7ad98d 100644
--- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java
+++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.node.maintainer;
import java.io.IOException;
+import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.NoSuchFileException;
@@ -63,7 +64,7 @@ public class FileHelper {
throw new IllegalArgumentException("Number of files to keep must be a positive number");
}
- List<Path> pathsInDeleteDir = Files.list(basePath)
+ List<Path> pathsInDeleteDir = listContentsOfDirectory(basePath).stream()
.filter(Files::isRegularFile)
.sorted(Comparator.comparing(FileHelper::getLastModifiedTime))
.skip(nMostRecentToKeep)
@@ -153,13 +154,16 @@ public class FileHelper {
return pattern == null || pattern.matcher(path.getFileName().toString()).find();
}
- static List<Path> listContentsOfDirectory(Path basePath) {
+ /**
+ * @return list all files in a directory, returns empty list if directory does not exist
+ */
+ public static List<Path> listContentsOfDirectory(Path basePath) {
try (Stream<Path> directoryStream = Files.list(basePath)) {
return directoryStream.collect(Collectors.toList());
} catch (NoSuchFileException ignored) {
return Collections.emptyList();
} catch (IOException e) {
- throw new RuntimeException("Failed to list contents of directory " + basePath.toAbsolutePath(), e);
+ throw new UncheckedIOException("Failed to list contents of directory " + basePath.toAbsolutePath(), e);
}
}
@@ -167,7 +171,7 @@ public class FileHelper {
try {
return Files.getLastModifiedTime(path, LinkOption.NOFOLLOW_LINKS);
} catch (IOException e) {
- throw new RuntimeException("Failed to get last modified time of " + path.toAbsolutePath(), e);
+ throw new UncheckedIOException("Failed to get last modified time of " + path.toAbsolutePath(), e);
}
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java
index 93704d244b5..d31b4438a38 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java
@@ -9,7 +9,6 @@ import com.yahoo.transaction.Mutex;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.NodeRepository;
-import java.time.Clock;
import java.util.List;
/**
@@ -67,6 +66,7 @@ public class GroupPreparer {
allocation.offer(prioritizer.prioritize());
if (! allocation.fullfilled())
throw new OutOfCapacityException("Could not satisfy " + requestedNodes + " for " + cluster +
+ " in " + application.toShortString() +
outOfCapacityDetails(allocation));
// Extend reservation for already reserved nodes
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java
index e1b1d74c6d0..2cabee98c0d 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java
@@ -160,13 +160,13 @@ public class DockerProvisioningTest {
assertEquals(setOf("host1", "host2"), hostsOf(tester.getNodes(application1, Node.State.active)));
try {
- ApplicationId application2 = tester.makeApplicationId();
+ ApplicationId application2 = ApplicationId.from("tenant1", "app1", "default");
prepareAndActivate(application2, 3, false, tester);
fail("Expected allocation failure");
}
catch (Exception e) {
assertEquals("No room for 3 nodes as 2 of 4 hosts are exclusive",
- "Could not satisfy request for 3 nodes of flavor 'dockerSmall' for container cluster 'myContainer' group 0 6.39: Not enough nodes available due to host exclusivity constraints.",
+ "Could not satisfy request for 3 nodes of flavor 'dockerSmall' for container cluster 'myContainer' group 0 6.39 in tenant1.app1: Not enough nodes available due to host exclusivity constraints.",
e.getMessage());
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java
index c0cead74f5f..11c7832091b 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java
@@ -29,6 +29,7 @@ import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.Optional;
+import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*;
import static com.yahoo.vespa.athenz.tls.KeyAlgorithm.RSA;
import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA;
import static java.util.Collections.emptySet;
@@ -161,7 +162,7 @@ public class NodeIdentifierTest {
Pkcs10Csr csr = Pkcs10CsrBuilder
.fromKeypair(new X500Principal("CN=" + TENANT_NODE_IDENTITY), KEYPAIR, SHA256_WITH_RSA)
.build();
- VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment);
+ VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment, NODE);
X509Certificate certificate = X509CertificateBuilder
.fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1)
.addSubjectAlternativeName(vespaUniqueInstanceId.asDottedString() + ".instanceid.athenz.provider-name.vespa.yahoo.cloud")
diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java
index a09ec29dada..d1d5f3e8c95 100644
--- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java
+++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java
@@ -46,7 +46,7 @@ public class ServiceMonitorInstanceLookupService implements InstanceLookupServic
return Optional.empty();
}
if (applicationInstancesUsingHost.size() > 1) {
- throw new AssertionError(
+ throw new IllegalStateException(
"Major assumption broken: Multiple application instances contain host " + hostName.s()
+ ": " + applicationInstancesUsingHost);
}
diff --git a/parent/pom.xml b/parent/pom.xml
index 411cc5ede9e..10e93d4ffbf 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -81,7 +81,7 @@
<plugin>
<groupId>org.apache.felix</groupId>
<artifactId>maven-bundle-plugin</artifactId>
- <version>2.4.0</version>
+ <version>3.5.0</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
@@ -498,11 +498,6 @@
<version>${antlr4.version}</version>
</dependency>
<dependency>
- <groupId>org.apache.aries.spifly</groupId>
- <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId>
- <version>${aries.spifly.version}</version>
- </dependency>
- <dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.1</version>
@@ -686,9 +681,6 @@
<properties>
<antlr.version>3.5.2</antlr.version>
<antlr4.version>4.5</antlr4.version>
- <aries.spifly.version>1.0.8</aries.spifly.version>
- <aries.util.version>1.0.0</aries.util.version>
- <asm-debug-all.version>5.0.3</asm-debug-all.version>
<!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories -->
<athenz.version>1.7.43</athenz.version>
<commons-lang.version>2.6</commons-lang.version>
diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp
index 9adb86147b6..61823a17f09 100644
--- a/searchcore/src/tests/proton/matching/query_test.cpp
+++ b/searchcore/src/tests/proton/matching/query_test.cpp
@@ -107,6 +107,8 @@ class Test : public vespalib::TestApp {
void requireThatParallelWandBlueprintsAreCreatedCorrectly();
void requireThatWhiteListBlueprintCanBeUsed();
void requireThatSameElementTermsAreProperlyPrefixed();
+ void requireThatSameElementDoesNotAllocateMatchData();
+ void requireThatSameElementIteratorsCanBeBuilt();
public:
~Test();
@@ -181,12 +183,7 @@ Node::UP buildQueryTree(const ViewResolver &resolver,
query_builder.addPhrase(2, field, 7, Weight(0));
query_builder.addStringTerm(phrase_term, field, 8, Weight(0));
query_builder.addStringTerm(phrase_term, field, 9, Weight(0));
-#if 0
- //Todo add testing when SameElement blueprints are ready
- query_builder.addSameElement(2, field);
- query_builder.addStringTerm(string_term, field, 10, Weight(0));
- query_builder.addStringTerm(prefix_term, field, 11, Weight(0));
-#endif
+
Node::UP node = query_builder.build();
ResolveViewVisitor visitor(resolver, idxEnv);
@@ -194,6 +191,19 @@ Node::UP buildQueryTree(const ViewResolver &resolver,
return node;
}
+Node::UP buildSameElementQueryTree(const ViewResolver &resolver,
+ const search::fef::IIndexEnvironment &idxEnv)
+{
+ QueryBuilder<ProtonNodeTypes> query_builder;
+ query_builder.addSameElement(2, field);
+ query_builder.addStringTerm(string_term, field, 0, Weight(0));
+ query_builder.addStringTerm(prefix_term, field, 1, Weight(0));
+ Node::UP node = query_builder.build();
+ ResolveViewVisitor visitor(resolver, idxEnv);
+ node->accept(visitor);
+ return node;
+}
+
void Test::requireThatMatchDataIsReserved() {
Node::UP node = buildQueryTree(ViewResolver(), plain_index_env);
@@ -883,6 +893,7 @@ make_same_element_stack_dump(const vespalib::string &prefix, const vespalib::str
query->accept(sem);
return query;
}
+
void
Test::requireThatSameElementTermsAreProperlyPrefixed()
{
@@ -915,6 +926,32 @@ Test::requireThatSameElementTermsAreProperlyPrefixed()
EXPECT_EQUAL(dynamic_cast<ProtonStringTerm *>(root->getChildren()[1])->getView(), "abc.abc.f2");
}
+void
+Test::requireThatSameElementDoesNotAllocateMatchData()
+{
+ Node::UP node = buildSameElementQueryTree(ViewResolver(), plain_index_env);
+ MatchDataLayout mdl;
+ MatchDataReserveVisitor visitor(mdl);
+ node->accept(visitor);
+ MatchData::UP match_data = mdl.createMatchData();
+ EXPECT_EQUAL(0u, match_data->getNumTermFields());
+}
+
+void
+Test::requireThatSameElementIteratorsCanBeBuilt() {
+ Node::UP node = buildSameElementQueryTree(ViewResolver(), plain_index_env);
+ FakeSearchContext context(10);
+ context.addIdx(0).idx(0).getFake()
+ .addResult(field, string_term, FakeResult()
+ .doc(4).elem(1).pos(0).doc(8).elem(1).pos(0))
+ .addResult(field, prefix_term, FakeResult()
+ .doc(4).elem(2).pos(0).doc(8).elem(1).pos(1));
+ SearchIterator::UP iterator = getIterator(*node, context);
+ ASSERT_TRUE(iterator.get());
+ EXPECT_TRUE(!iterator->seek(4));
+ EXPECT_TRUE(iterator->seek(8));
+}
+
Test::~Test() = default;
int
@@ -937,7 +974,6 @@ Test::Main()
TEST_CALL(requireThatNearIteratorsCanBeBuilt);
TEST_CALL(requireThatONearIteratorsCanBeBuilt);
TEST_CALL(requireThatPhraseIteratorsCanBeBuilt);
- //TODO Add SameElement testing
TEST_CALL(requireThatUnknownFieldActsEmpty);
TEST_CALL(requireThatIllegalFieldsAreIgnored);
TEST_CALL(requireThatQueryGluesEverythingTogether);
@@ -949,7 +985,8 @@ Test::Main()
TEST_CALL(requireThatParallelWandBlueprintsAreCreatedCorrectly);
TEST_CALL(requireThatWhiteListBlueprintCanBeUsed);
TEST_CALL(requireThatSameElementTermsAreProperlyPrefixed);
-
+ TEST_CALL(requireThatSameElementDoesNotAllocateMatchData);
+ TEST_CALL(requireThatSameElementIteratorsCanBeBuilt);
TEST_DONE();
}
diff --git a/searchcore/src/tests/proton/matching/querynodes_test.cpp b/searchcore/src/tests/proton/matching/querynodes_test.cpp
index 7b6fdd1ae88..6607019cccc 100644
--- a/searchcore/src/tests/proton/matching/querynodes_test.cpp
+++ b/searchcore/src/tests/proton/matching/querynodes_test.cpp
@@ -25,6 +25,7 @@
#include <vespa/searchlib/queryeval/ranksearch.h>
#include <vespa/searchlib/queryeval/searchiterator.h>
#include <vespa/searchlib/queryeval/simple_phrase_search.h>
+#include <vespa/searchlib/queryeval/same_element_search.h>
#include <vespa/searchlib/queryeval/sourceblendersearch.h>
#include <vespa/searchlib/queryeval/fake_search.h>
#include <vespa/searchlib/queryeval/fake_requestcontext.h>
@@ -39,28 +40,30 @@ using search::fef::FieldInfo;
using search::fef::FieldType;
using search::fef::MatchData;
using search::fef::MatchDataLayout;
-using search::fef::TermFieldMatchData;
using search::fef::TermFieldHandle;
+using search::fef::TermFieldMatchData;
using search::fef::TermFieldMatchDataArray;
using search::fef::test::IndexEnvironment;
using search::query::Node;
using search::query::QueryBuilder;
+using search::queryeval::AndNotSearch;
+using search::queryeval::AndSearch;
+using search::queryeval::Blueprint;
+using search::queryeval::EmptySearch;
+using search::queryeval::FakeRequestContext;
+using search::queryeval::FakeResult;
+using search::queryeval::FakeSearch;
+using search::queryeval::FieldSpec;
using search::queryeval::ISourceSelector;
using search::queryeval::NearSearch;
using search::queryeval::ONearSearch;
using search::queryeval::OrSearch;
-using search::queryeval::AndSearch;
-using search::queryeval::AndNotSearch;
using search::queryeval::RankSearch;
-using search::queryeval::Blueprint;
+using search::queryeval::SameElementSearch;
using search::queryeval::SearchIterator;
-using search::queryeval::SourceBlenderSearch;
-using search::queryeval::FieldSpec;
using search::queryeval::Searchable;
-using search::queryeval::FakeSearch;
-using search::queryeval::FakeResult;
-using search::queryeval::FakeRequestContext;
using search::queryeval::SimplePhraseSearch;
+using search::queryeval::SourceBlenderSearch;
using std::string;
using std::vector;
using namespace proton::matching;
@@ -287,6 +290,20 @@ SearchIterator *getParent<ONear>(SearchIterator *a, SearchIterator *b) {
}
template <>
+SearchIterator *getParent<SameElement>(SearchIterator *a, SearchIterator *b) {
+ std::vector<SearchIterator::UP> children;
+ children.emplace_back(a);
+ children.emplace_back(b);
+ TermFieldMatchDataArray data;
+ static TermFieldMatchData tmd;
+ // we only check how many term/field combinations
+ // are below the SameElement parent:
+ // two terms searching in one index field
+ data.add(&tmd).add(&tmd);
+ return new SameElementSearch(nullptr, std::move(children), data, true);
+}
+
+template <>
SearchIterator *getParent<Or>(SearchIterator *a, SearchIterator *b) {
return getSimpleParent<OrSearch>(a, b);
}
@@ -422,6 +439,7 @@ void checkProperBlending() {
TEST_DO(checkOneFieldNoAttributesOneIndex<T>());
}
+
template <typename T>
void checkProperBlendingWithParent() {
IteratorStructureTest structure_test;
@@ -454,6 +472,24 @@ void checkProperBlendingWithParent() {
EXPECT_EQUAL(expected->asString(), structure_test.getIteratorAsString<T>());
}
+template <>
+void checkProperBlendingWithParent<SameElement>() {
+ using T = SameElement;
+ IteratorStructureTest structure_test;
+ structure_test.setFieldCount(1);
+ structure_test.setAttributeCount(0);
+ structure_test.setIndexCount(2);
+
+ SearchIterator::UP expected(
+ getParent<T>(Blender()
+ .add(SourceId(0), getTerm(phrase_term1, field[0], source_tag[0]))
+ .add(SourceId(1), getTerm(phrase_term1, field[0], source_tag[1])),
+ Blender(bothStrict<T>())
+ .add(SourceId(0), getTerm(phrase_term2, field[0], source_tag[0]))
+ .add(SourceId(1), getTerm(phrase_term2, field[0], source_tag[1]))));
+ EXPECT_EQUAL(expected->asString(), structure_test.getIteratorAsString<T>());
+}
+
TEST("requireThatTermNodeSearchIteratorsGetProperBlending") {
TEST_DO(checkProperBlending<Term>());
}
@@ -463,8 +499,7 @@ TEST("requireThatPhrasesGetProperBlending") {
}
TEST("requireThatSameElementGetProperBlending") {
- //TODO SameEelement needs proper testing/implementation
- //TEST_DO(checkProperBlending<SameElement>());
+ TEST_DO(checkProperBlendingWithParent<SameElement>());
}
TEST("requireThatNearGetProperBlending") {
diff --git a/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp b/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp
index 5ea2bcc982b..4fd079949d5 100644
--- a/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp
+++ b/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp
@@ -136,6 +136,23 @@ TEST_F("require that equiv nodes resolve view from children", Fixture) {
EXPECT_EQUAL(field2, base.field(1).field_name);
}
+TEST_F("require that view is resolved for SameElement children", Fixture) {
+ ViewResolver resolver;
+ resolver.add(view, field1);
+
+ QueryBuilder<ProtonNodeTypes> builder;
+ builder.addSameElement(2, "");
+ ProtonStringTerm &my_term = builder.addStringTerm(term, view, 42, weight);
+ builder.addStringTerm(term, field2, 43, weight);
+ Node::UP node = builder.build();
+
+ ResolveViewVisitor visitor(resolver, f.index_environment);
+ node->accept(visitor);
+
+ ASSERT_EQUAL(1u, my_term.numFields());
+ EXPECT_EQUAL(field1, my_term.field(0).field_name);
+}
+
} // namespace
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 721214f9e94..4b49f17f74e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,5 +1,4 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -13,76 +12,61 @@ import java.util.Map;
import java.util.regex.Pattern;
/**
- * The result of importing a TensorFlow model into Vespa.
- * - A set of signatures which are named collections of inputs and outputs.
- * - A set of named constant tensors represented by Variable nodes in TensorFlow.
- * - A list of warning messages.
+ * The result of importing a model (TensorFlow or ONNX) into Vespa.
*
* @author bratseth
*/
-// This object can be built incrementally within this package, but is immutable when observed from outside the package
-public class TensorFlowModel {
+public class ImportedModel {
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+ private static final String defaultSignatureName = "default";
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
/**
- * Creates a TensorFlow model
+ * Creates a new imported model.
*
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
*/
- public TensorFlowModel(String name) {
+ public ImportedModel(String name) {
if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
+ throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
this.name = name;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
- private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
- return signatures.computeIfAbsent(name, Signature::new);
- }
-
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/**
* Returns an immutable map of the small constants of this.
* These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow source.
+ * values given in the TensorFlow or ONNX source.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
/**
* Returns an immutable map of the large constants of this.
- * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
- * and correspond to Variable files stored separately in TensorFlow.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
+ * For TensorFlow this corresponds to Variable files stored separately.
*/
public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
- * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
- * which are not Placeholders or Variables (which instead become respectively arguments and constants).
- * Note that only nodes recursively referenced by a placeholder are added.
+ * Returns an immutable map of the expressions of this - corresponding to graph nodes
+ * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder/input are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -95,9 +79,26 @@ public class TensorFlowModel {
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, Signature::new);
+ }
+
+ /** Convenience method for returning a default signature */
+ Signature defaultSignature() { return signature(defaultSignatureName); }
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
- * and outputs maps to expressions nodes.
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument
+ * ("placeholder") names+types, and outputs maps to expressions nodes.
+ * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
+ * concept of signatures. For now, we handle ONNX models as having a single signature.
*/
public class Signature {
@@ -107,19 +108,14 @@ public class TensorFlowModel {
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
- Signature(String name) {
+ public Signature(String name) {
this.name = name;
}
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
-
public String name() { return name; }
/** Returns the result this is part of */
- TensorFlowModel owner() { return TensorFlowModel.this; }
+ public ImportedModel owner() { return ImportedModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
@@ -127,7 +123,7 @@ public class TensorFlowModel {
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ /** Returns the type of the argument this input references */
public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
@@ -144,12 +140,17 @@ public class TensorFlowModel {
*/
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
- /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ /** Returns the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
@Override
public String toString() { return "signature '" + name + "'"; }
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
new file mode 100644
index 00000000000..a658833b426
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -0,0 +1,242 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
+ * ranking expressions. The general mechanism for import is for the
+ * specific ML platform import implementations to create an
+ * IntermediateGraph. This class offers common code to convert the
+ * IntermediateGraph to Vespa ranking expressions and macros.
+ *
+ * @author lesters
+ */
+public abstract class ModelImporter {
+
+ private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
+
+ /**
+ * The main import function.
+ */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
+ public ImportedModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ /**
+ * Takes an IntermediateGraph and converts it to a ImportedModel containing
+ * the actual Vespa ranking expressions.
+ */
+ static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
+ ImportedModel model = new ImportedModel(graph.name());
+
+ graph.optimize();
+
+ importSignatures(graph, model);
+ importExpressions(graph, model);
+ reportWarnings(graph, model);
+ logVariableTypes(graph);
+
+ return model;
+ }
+
+ private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
+ for (String signatureName : graph.signatures()) {
+ ImportedModel.Signature signature = model.signature(signatureName);
+ for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
+ signature.input(input.getKey(), input.getValue());
+ }
+ for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
+ signature.output(output.getKey(), output.getValue());
+ }
+ }
+ }
+
+ private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Convert intermediate representation to Vespa ranking expressions.
+ */
+ static void importExpressions(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(operation, model);
+ }
+ importExpressionInputs(operation, model);
+ importRankingExpression(operation, model);
+ importArgumentExpression(operation, model);
+ importMacroExpression(operation, model);
+
+ return operation.function();
+ }
+
+ private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
+ operation.inputs().forEach(input -> importExpression(input, model));
+ }
+
+ private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Value value = operation.getConstantValue().orElseThrow(() ->
+ new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+ "is constant but does not have a value."));
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+
+ Tensor tensor = value.asTensor();
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.function().isPresent()) {
+ String name = operation.name();
+ if (!model.expressions().containsKey(name)) {
+ TensorFunction function = operation.function().get();
+
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced from the output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Imported function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.isInput()) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
+ model.argument(operation.vespaName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.macro().isPresent()) {
+ TensorFunction function = operation.macro().get();
+ try {
+ model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+
+ /**
+ * Add any import warnings to the signature in the ImportedModel.
+ */
+ private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ reportWarnings(graph.get(outputName), model);
+ }
+ }
+ }
+
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ for (String warning : operation.warnings()) {
+ model.defaultSignature().importWarning(warning);
+ }
+ for (IntermediateOperation input : operation.inputs()) {
+ reportWarnings(input, model);
+ }
+ }
+
+ /**
+ * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(IntermediateGraph graph) {
+ for (IntermediateOperation operation : graph.operations()) {
+ if ( ! (operation instanceof Constant)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+ log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
new file mode 100644
index 00000000000..d3dd2a1d418
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -0,0 +1,30 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
+import onnx.Onnx;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+
+/**
+ * Converts a ONNX model into a ranking expression and set of constants.
+ *
+ * @author lesters
+ */
+public class OnnxImporter extends ModelImporter {
+
+ @Override
+ public ImportedModel importModel(String modelName, String modelPath) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
new file mode 100644
index 00000000000..ff584559a83
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -0,0 +1,47 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import org.tensorflow.SavedModelBundle;
+
+import java.io.IOException;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorFlowImporter extends ModelImporter {
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a .pbtxt or .pb file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportedModel importModel(String modelName, String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importModel(modelName, model);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ /** Imports a TensorFlow model */
+ ImportedModel importModel(String modelName, SavedModelBundle model) {
+ try {
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
+ }
+ }
+
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
index c5ac7ace0fc..e1294ec3e01 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
@@ -1,7 +1,8 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -24,7 +25,7 @@ public class VariableConverter {
*/
public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName,
+ return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
bundle),
OrderedTensorType.fromSpec(orderedTypeSpec)));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
index 2524417cee0..38f1d2329e2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -47,7 +47,7 @@ public class DimensionRenamer {
/**
* Add a constraint between dimension names.
*/
- public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) {
+ public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
Arc opposite = arc.opposite();
constraints.put(arc, pred);
@@ -175,9 +175,9 @@ public class DimensionRenamer {
private final String from;
private final String to;
- private final OnnxOperation operation;
+ private final IntermediateOperation operation;
- Arc(String from, String to, OnnxOperation operation) {
+ Arc(String from, String to, IntermediateOperation operation) {
this.from = from;
this.to = to;
this.operation = operation;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
new file mode 100644
index 00000000000..39a8b211d09
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Holds an intermediate representation of an imported ONNX or TensorFlow
+ * graph. After this intermediate representation is constructed, it is used to
+ * simplify and optimize the computational graph and then converted into the
+ * final ImportedModel that holds the Vespa ranking expressions for the model.
+ *
+ * @author lesters
+ */
+public class IntermediateGraph {
+
+ private final String modelName;
+ private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, GraphSignature> signatures = new HashMap<>();
+
+ private static class GraphSignature {
+ final Map<String, String> inputs = new HashMap<>();
+ final Map<String, String> outputs = new HashMap<>();
+ }
+
+ public IntermediateGraph(String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String name() {
+ return modelName;
+ }
+
+ public IntermediateOperation put(String key, IntermediateOperation operation) {
+ return index.put(key, operation);
+ }
+
+ public IntermediateOperation get(String key) {
+ return index.get(key);
+ }
+
+ public Set<String> signatures() {
+ return signatures.keySet();
+ }
+
+ public Map<String, String> inputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
+ }
+
+ public Map<String, String> outputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
+ }
+
+ public String defaultSignature() {
+ return "default";
+ }
+
+ public boolean alreadyImported(String key) {
+ return index.containsKey(key);
+ }
+
+ public Collection<IntermediateOperation> operations() {
+ return index.values();
+ }
+
+ public void optimize() {
+ renameDimensions();
+ }
+
+ /**
+ * Find dimension names to avoid excessive renaming while evaluating the model.
+ */
+ private void renameDimensions() {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
index 812e9b8d678..209d73a9f38 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.tensor.TensorType;
-import onnx.Onnx;
+import com.yahoo.tensor.TensorTypeParser;
import java.util.ArrayList;
import java.util.Collections;
@@ -13,9 +13,9 @@ import java.util.stream.Collectors;
/**
* A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. ONNX tensors have an explicit ordering of their dimensions.
+ * names. Imported tensors have an explicit ordering of their dimensions.
* During import, we need to track the Vespa dimension that matches the
- * corresponding ONNX dimension as the ordering can change after
+ * corresponding imported dimension as the ordering can change after
* dimension renaming. That is the purpose of this class.
*
* @author lesters
@@ -25,14 +25,14 @@ public class OrderedTensorType {
private final TensorType type;
private final List<TensorType.Dimension> dimensions;
- private final long[] innerSizesOnnx;
+ private final long[] innerSizesOriginal;
private final long[] innerSizesVespa;
private final int[] dimensionMap;
private OrderedTensorType(List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesOnnx = new long[dimensions.size()];
+ this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
}
@@ -54,10 +54,10 @@ public class OrderedTensorType {
if (numDimensions == 0) {
return null;
}
- innerSizesOnnx[numDimensions - 1] = 1;
+ innerSizesOriginal[numDimensions - 1] = 1;
innerSizesVespa[numDimensions - 1] = 1;
for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1];
+ innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
}
int[] mapping = new int[numDimensions];
@@ -74,11 +74,15 @@ public class OrderedTensorType {
return mapping;
}
+ public int dimensionMap(int originalIndex) {
+ return dimensionMap[originalIndex];
+ }
+
/**
- * When dimension ordering between Vespa and Onnx differs, i.e.
+ * When dimension ordering between Vespa and imported differs, i.e.
* after dimension renaming, use the dimension map to read in values
* so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from Onnx.
+ * Used when importing tensors.
*/
public int toDirectIndex(int index) {
if (dimensions.size() == 0) {
@@ -90,9 +94,9 @@ public class OrderedTensorType {
int directIndex = 0;
long rest = index;
for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesOnnx[i];
+ long address = rest / innerSizesOriginal[i];
directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesOnnx[i];
+ rest %= innerSizesOriginal[i];
}
return directIndex;
}
@@ -116,22 +120,6 @@ public class OrderedTensorType {
return true;
}
- public void verifyType(Onnx.TypeProto typeProto) {
- Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
- }
- for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
- int vespaIndex = dimensionMap[onnxIndex];
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions");
- }
- }
- }
- }
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
@@ -151,18 +139,13 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- Builder builder = new Builder(shape);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
+ public OrderedTensorType rename(String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ Optional<Long> dimSize = dimensions.get(i).size();
+ if (dimSize.isPresent() && dimSize.get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -170,13 +153,13 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
- Builder builder = new Builder();
- for (int i = 0; i < dims.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
- if (dimSize >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -184,13 +167,46 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType standardType(OrderedTensorType type) {
- Builder builder = new Builder();
- for (int i = 0; i < type.dimensions().size(); ++ i) {
- TensorType.Dimension dim = type.dimensions().get(i);
- String dimensionName = "d" + i;
- if (dim.size().isPresent() && dim.size().get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims) {
+ return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -200,45 +216,13 @@ public class OrderedTensorType {
public static class Builder {
- private final Onnx.TensorShapeProto shape;
private final List<TensorType.Dimension> dimensions;
- public Builder(Onnx.TensorShapeProto shape) {
- this.shape = shape;
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
public Builder() {
- this.shape = null;
this.dimensions = new ArrayList<>();
}
public Builder add(TensorType.Dimension vespaDimension) {
- if (shape != null) {
- int index = dimensions.size();
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index);
- long size = onnxDimension.getDimValue();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- }
- }
this.dimensions.add(vespaDimension);
return this;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
new file mode 100644
index 00000000000..3fe92440cae
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -0,0 +1,216 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import onnx.Onnx;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+
+ switch (node.getOpType().toLowerCase()) {
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
+ case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
+ Onnx.GraphProto onnxGraph = model.getGraph();
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importOperations(onnxGraph, intermediateGraph);
+ verifyOutputTypes(onnxGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
+ importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
+ }
+ }
+
+ private static IntermediateOperation importOperation(String name,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ if (intermediateGraph.alreadyImported(name)) {
+ return intermediateGraph.get(name);
+ }
+ IntermediateOperation operation;
+ if (isArgumentTensor(name, onnxGraph)) {
+ Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
+ if (valueInfoProto == null)
+ throw new IllegalArgumentException("Could not find argument tensor: " + name);
+ OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
+ operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
+
+ intermediateGraph.inputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+
+ } else if (isConstantTensor(name, onnxGraph)) {
+ Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ operation = new Constant(intermediateGraph.name(), name, defaultType);
+ operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+
+ } else {
+ Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph);
+
+ if (isOutputNode(name, onnxGraph)) {
+ intermediateGraph.outputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ }
+ }
+ intermediateGraph.put(operation.vespaName(), operation);
+
+ return operation;
+ }
+
+ private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor == null;
+ }
+
+ private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor != null;
+ }
+
+ private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
+ if (tensorProto.getName().equals(name)) {
+ return tensorProto;
+ }
+ }
+ return null;
+ }
+
+ private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
+ return getOutputNode(name, graph) != null;
+ }
+
+ private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
+ if (nodeName.equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ return node.getInputList().stream()
+ .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
+ .collect(Collectors.toList());
+ }
+
+ private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(onnxNode.getType(), type);
+ }
+ }
+
+ private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
+ boolean hasPortNumber = nodeName.contains(":");
+ for (Onnx.NodeProto node : graph.getNodeList()) {
+ if (hasPortNumber) {
+ for (String outputName : node.getOutputList()) {
+ if (outputName.equals(nodeName)) {
+ return node;
+ }
+ }
+ } else if (node.getName().equals(nodeName)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
index 2912db03b5f..18856d4a25f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
import com.google.protobuf.ByteString;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import onnx.Onnx;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
-import java.util.List;
/**
* Converts Onnx tensors into Vespa tensors.
@@ -29,7 +28,6 @@ public class TensorConverter {
return builder.build();
}
- /* todo: support more types */
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
new file mode 100644
index 00000000000..715c55d8323
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
@@ -0,0 +1,52 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+/**
+ * Converts and verifies ONNX tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
+ Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
+ }
+ for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
+ int vespaIndex = type.dimensionMap(onnxIndex);
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ if (onnxDimension.getDimValue() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 1619c11427a..7fc2aae87d1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,28 +1,29 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
+import java.util.Collections;
import java.util.List;
-public class Placeholder extends TensorFlowOperation {
+public class Argument extends IntermediateOperation {
private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
- public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- standardNamingType = OrderedTensorType.fromTensorFlowType(node);
+ public Argument(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
+ this.type = type.rename(vespaName() + "_");
+ standardNamingType = OrderedTensorType.standardType(type);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 4f5d61d75f9..1b8c62fe0e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,38 +1,37 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class ConcatV2 extends TensorFlowOperation {
+public class ConcatV2 extends IntermediateOperation {
private String concatDimensionName;
- public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return null;
}
- TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a constant.");
}
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a scalar.");
}
@@ -44,7 +43,7 @@ public class ConcatV2 extends TensorFlowOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"inputs must have save rank.");
}
for (int j = 0; j < aType.rank(); ++j) {
@@ -53,13 +52,13 @@ public class ConcatV2 extends TensorFlowOperation {
if (j == concatDim) {
concatDimSize += dimSizeB;
} else if (dimSizeA != dimSizeB) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"input dimension " + j + " differs in input tensors.");
}
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDim) {
@@ -75,7 +74,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
return null;
}
TensorFunction result = inputs.get(0).function().get();
@@ -88,7 +87,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return;
}
OrderedTensorType a = inputs.get(0).type().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
index 718e2a4b3c2..3c0f8569c47 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
@@ -1,36 +1,38 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Const extends TensorFlowOperation {
+public class Const extends IntermediateOperation {
- public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final AttributeMap attributeMap;
+
+ public Const(String modelName,
+ String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ OrderedTensorType type) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.type = type.rename(vespaName() + "_");
setConstantValue(value());
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
@@ -55,7 +57,7 @@ public class Const extends TensorFlowOperation {
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
@@ -77,24 +79,11 @@ public class Const extends TensorFlowOperation {
}
private Value value() {
- if ( ! node.getAttrMap().containsKey("value")) {
- throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
- "const has missing 'value' attribute");
- }
- AttrValue attrValue = node.getAttrMap().get("value");
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return new BooleanValue(attrValue.getB());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return new DoubleValue(attrValue.getI());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return new DoubleValue(attrValue.getF());
+ Optional<Value> value = attributeMap.get("value", type);
+ if ( ! value.isPresent()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "const has missing or non-recognized 'value' attribute");
}
- throw new IllegalArgumentException("Requesting value of constant in " +
- node.getName() + " but type is not recognized.");
+ return value.get();
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
index 13043a61a8e..5e4abeaa234 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
@@ -1,38 +1,34 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.Collections;
import java.util.Optional;
-public class Constant extends OnnxOperation {
+public class Constant extends IntermediateOperation {
- final String modelName;
- final Onnx.TensorProto tensorProto;
+ private final String modelName;
- public Constant(String modelName, Onnx.TensorProto tensorProto) {
- super(null, Collections.emptyList());
+ public Constant(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
this.modelName = modelName;
- this.tensorProto = tensorProto;
+ this.type = type.rename(vespaName() + "_");
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + vespaName(tensorProto.getName());
+ return modelName + "_" + vespaName(name);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_");
+ return type;
}
@Override
@@ -40,9 +36,14 @@ public class Constant extends OnnxOperation {
return null; // will be added by function() since this is constant.
}
+ /**
+ * Constant values are sent in via the constantValueFunction, as the
+ * dimension names and thus the data layout depends on the dimension
+ * renaming which happens after the conversion to intermediate graph.
+ */
@Override
public Optional<Value> getConstantValue() {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+ return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
index 2d0f4c7042b..742ed8b89ab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -12,18 +12,17 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
-public class ExpandDims extends TensorFlowOperation {
+public class ExpandDims extends IntermediateOperation {
private List<String> expandDimensions;
- public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -32,14 +31,14 @@ public class ExpandDims extends TensorFlowOperation {
return null;
}
- TensorFlowOperation axisOperation = inputs().get(1);
+ IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis argument must be a scalar.");
}
@@ -49,7 +48,7 @@ public class ExpandDims extends TensorFlowOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
index 1408e7e04f0..d29bd4b7a9e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
@@ -1,22 +1,21 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Identity extends TensorFlowOperation {
+public class Identity extends IntermediateOperation {
- public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 3687bba8b85..43de29cedd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
@@ -20,43 +19,40 @@ import java.util.Optional;
import java.util.function.Function;
/**
- * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
+ * Wraps an imported operation node and produces the respective Vespa tensor
+ * operation. During import, a graph of these operations are constructed. Then,
+ * the types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper Vespa
+ * expressions can be extracted.
*
* @author lesters
*/
-public abstract class TensorFlowOperation {
-
- protected final static String MACRO_PREFIX = "tf_macro_";
+public abstract class IntermediateOperation {
- private final String modelName;
+ private final static String MACRO_PREFIX = "imported_ml_macro_";
- protected final NodeDef node;
- protected final int port;
- protected final List<TensorFlowOperation> inputs;
- protected final List<TensorFlowOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
+ protected final String name;
+ protected final String modelName;
+ protected final List<IntermediateOperation> inputs;
+ protected final List<IntermediateOperation> outputs = new ArrayList<>();
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction macro = null;
+ private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
- private List<TensorFlowOperation> controlInputs = Collections.emptyList();
+ private List<IntermediateOperation> controlInputs = Collections.emptyList();
- TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ protected Function<OrderedTensorType, Value> constantValueFunction = null;
+
+ IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
+ this.name = name;
this.modelName = modelName;
- this.node = node;
- this.port = port;
this.inputs = Collections.unmodifiableList(inputs);
this.inputs.forEach(i -> i.outputs.add(this));
}
- protected String modelName() { return modelName; }
-
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
@@ -65,9 +61,6 @@ public abstract class TensorFlowOperation {
if (type == null) {
type = lazyGetType();
}
- if (type != null) {
- type.verifyType(node);
- }
return Optional.ofNullable(type);
}
@@ -87,14 +80,14 @@ public abstract class TensorFlowOperation {
return Optional.ofNullable(function);
}
- /** Return TensorFlow node */
- public NodeDef node() { return node; }
+ /** Returns original name of this operation node */
+ public String name() { return name; }
/** Return unmodifiable list of inputs */
- public List<TensorFlowOperation> inputs() { return inputs; }
+ public List<IntermediateOperation> inputs() { return inputs; }
/** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
+ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a Vespa ranking expression that should be added as a macro */
public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
@@ -109,22 +102,34 @@ public abstract class TensorFlowOperation {
public boolean isInput() { return false; }
/** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
+ public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
/** Sets the constant value */
public void setConstantValue(Value value) { constantValue = value; }
/** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+ public Optional<Value> getConstantValue() {
+ if (constantValue != null) {
+ return Optional.of(constantValue);
+ }
+ if (constantValueFunction != null) {
+ return Optional.of(constantValueFunction.apply(type));
+ }
+ return Optional.empty();
+ }
+
+ /** Set the constant value function */
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
/** Sets the external control inputs */
- public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
+ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
/** Retrieve the control inputs for this operation */
- public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+ public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+ public String vespaName() { return vespaName(name); }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the valid Vespa name of this node if it is a macro */
public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; }
@@ -135,23 +140,48 @@ public abstract class TensorFlowOperation {
/** Set an input warning */
public void warning(String warning) { importWarnings.add(warning); }
- boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
- if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
- return false;
- }
+ boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
+ "for '" + name + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}
boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::type);
+ return verifyInputs(expected, IntermediateOperation::type);
}
boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::function);
+ return verifyInputs(expected, IntermediateOperation::function);
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This return the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+ /**
+ * An interface mapping operation attributes to Vespa Values.
+ * Adapter for differences in ONNX/TensorFlow.
+ */
+ public interface AttributeMap {
+ Optional<Value> get(String key);
+ Optional<Value> get(String key, OrderedTensorType type);
+ Optional<List<Value>> getList(String key);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index fe2004a528d..8413ed74118 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,24 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-public class Join extends OnnxOperation {
+public class Join extends IntermediateOperation {
private final DoubleBinaryOperator operator;
- public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) {
- super(node, inputs);
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
@@ -61,8 +59,8 @@ public class Join extends OnnxOperation {
return null;
}
- OnnxOperation a = largestInput();
- OnnxOperation b = smallestInput();
+ IntermediateOperation a = largestInput();
+ IntermediateOperation b = smallestInput();
List<String> aDimensionsToReduce = new ArrayList<>();
List<String> bDimensionsToReduce = new ArrayList<>();
@@ -107,13 +105,13 @@ public class Join extends OnnxOperation {
}
}
- private OnnxOperation largestInput() {
+ private IntermediateOperation largestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
}
- private OnnxOperation smallestInput() {
+ private IntermediateOperation smallestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
index c015f5ecba8..f54ae83052f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;
-public class Map extends TensorFlowOperation {
+public class Map extends IntermediateOperation {
private final DoubleUnaryOperator operator;
- public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
- super(modelName, node, inputs, port);
+ public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
index 1b388e2ae89..52e223f9518 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
@@ -1,21 +1,18 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-public class MatMul extends OnnxOperation {
+public class MatMul extends IntermediateOperation {
- public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, inputs);
+ public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 3eba872c6a0..95a77c07590 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -13,20 +14,20 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-public class Mean extends TensorFlowOperation {
+public class Mean extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> reduceDimensions;
- public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -34,9 +35,9 @@ public class Mean extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation reductionIndices = inputs.get(1);
+ IntermediateOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Mean in " + name + ": " +
"reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
@@ -54,7 +55,7 @@ public class Mean extends TensorFlowOperation {
return reducedType(inputType, shouldKeepDimensions());
}
- // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {
@@ -93,12 +94,12 @@ public class Mean extends TensorFlowOperation {
}
private boolean shouldKeepDimensions() {
- AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
+ Optional<Value> keepDims = attributeMap.get("keep_dims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if (!reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
index 4c95e67e184..9d9eca47b1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
@@ -1,21 +1,20 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Merge extends TensorFlowOperation {
+public class Merge extends IntermediateOperation {
- public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.type().isPresent()) {
return operation.type().get();
}
@@ -25,7 +24,7 @@ public class Merge extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.function().isPresent()) {
return operation.function().get();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
new file mode 100644
index 00000000000..19ba146492c
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
@@ -0,0 +1,26 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends IntermediateOperation {
+
+ public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
index 65ce7f00e34..9299ae9be12 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class PlaceholderWithDefault extends TensorFlowOperation {
+public class PlaceholderWithDefault extends IntermediateOperation {
- public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e7d90e5fc1f..e91c2305f7d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -1,10 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -19,19 +18,18 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Reshape extends TensorFlowOperation {
+public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,15 +37,15 @@ public class Reshape extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation newShape = inputs.get(1);
+ IntermediateOperation newShape = inputs.get(1);
if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Reshape in " + name + ": " +
"shape input must be a constant.");
}
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -124,7 +122,7 @@ public class Reshape extends TensorFlowOperation {
operators.add(0, ArithmeticOperator.MULTIPLY);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
- size *= TensorConverter.dimensionSize(dimension);
+ size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
operators.add(0, ArithmeticOperator.PLUS);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
index 5fdcb5a695f..927a4a368f9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
@@ -1,24 +1,23 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Select extends TensorFlowOperation {
+public class Select extends IntermediateOperation {
- public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,7 +38,7 @@ public class Select extends TensorFlowOperation {
if (!allInputFunctionsPresent(3)) {
return null;
}
- TensorFlowOperation conditionOperation = inputs().get(0);
+ IntermediateOperation conditionOperation = inputs().get(0);
TensorFunction a = inputs().get(1).function().get();
TensorFunction b = inputs().get(2).function().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
index af49d2c108b..da566909adc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Shape extends TensorFlowOperation {
+public class Shape extends IntermediateOperation {
- public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
createConstantValue();
}
@@ -24,7 +23,7 @@ public class Shape extends TensorFlowOperation {
return null;
}
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder(node)
+ return new OrderedTensorType.Builder()
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
index 17ce9e8b7cb..c750c47e27e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
@@ -1,26 +1,26 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
-public class Squeeze extends TensorFlowOperation {
+public class Squeeze extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> squeezeDimensions;
- public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -31,20 +31,21 @@ public class Squeeze extends TensorFlowOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
- AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
+ Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
+ if ( ! squeezeDimsAttr.isPresent()) {
squeezeDimensions = inputType.type().dimensions().stream().
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
} else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
- map(i -> inputType.type().dimensions().get(i.intValue())).
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(i -> inputType.type().dimensions().get(i)).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
}
+
return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
@@ -72,7 +73,7 @@ public class Squeeze extends TensorFlowOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
index de4d8862fd6..0171d1ea171 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
@@ -1,17 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Switch extends TensorFlowOperation {
+public class Switch extends IntermediateOperation {
- public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final int port;
+
+ public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
+ super(modelName, nodeName, inputs);
+ this.port = port;
}
@Override
@@ -21,7 +23,7 @@ public class Switch extends TensorFlowOperation {
}
Optional<OrderedTensorType> predicate = inputs.get(1).type();
if (predicate.get().type().rank() != 0) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a scalar");
}
return inputs.get(0).type().orElse(null);
@@ -29,13 +31,13 @@ public class Switch extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- TensorFlowOperation predicateOperation = inputs().get(1);
+ IntermediateOperation predicateOperation = inputs().get(1);
if (!predicateOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a constant");
}
if (port < 0 || port > 1) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"choice should be boolean");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
new file mode 100644
index 00000000000..a815cbc3944
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
@@ -0,0 +1,85 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Converts TensorFlow node attributes to Vespa attribute values.
+ *
+ * @author lesters
+ */
+public class AttributeConverter implements IntermediateOperation.AttributeMap {
+
+ private final Map<String, AttrValue> attributeMap;
+
+ public AttributeConverter(NodeDef node) {
+ attributeMap = node.getAttrMap();
+ }
+
+ public static AttributeConverter convert(NodeDef node) {
+ return new AttributeConverter(node);
+ }
+
+ @Override
+ public Optional<Value> get(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.empty(); // requires type
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return Optional.of(new BooleanValue(attrValue.getB()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return Optional.of(new DoubleValue(attrValue.getI()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return Optional.of(new DoubleValue(attrValue.getF()));
+ }
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<Value> get(String key, OrderedTensorType type) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
+ }
+ }
+ return get(key);
+ }
+
+ @Override
+ public Optional<List<Value>> getList(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
+ AttrValue.ListValue listValue = attrValue.getList();
+ if ( ! listValue.getBList().isEmpty()) {
+ return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getIList().isEmpty()) {
+ return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getFList().isEmpty()) {
+ return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ // add the rest
+ }
+ }
+ return Optional.empty();
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
new file mode 100644
index 00000000000..e1b292f9e61
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -0,0 +1,234 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(NodeDef node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+ int nodePort = IntermediateOperation.indexPartOf(nodeName);
+ OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
+ AttributeConverter attributes = AttributeConverter.convert(node);
+
+ switch (node.getOp().toLowerCase()) {
+ // array ops
+ case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
+ case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
+ case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "placeholder": return new Argument(modelName, nodeName, nodeType);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
+
+ // control flow
+ case "merge": return new Merge(modelName, nodeName, inputs);
+ case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
+
+ // math ops
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "mean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
+ case "select": return new Select(modelName, nodeName, inputs);
+ case "where3": return new Select(modelName, nodeName, inputs);
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+
+ // nn ops
+ case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+
+ // state ops
+ case "variable": return new Constant(modelName, nodeName, nodeType);
+ case "variablev2": return new Constant(modelName, nodeName, nodeType);
+
+ // evaluation no-ops
+ case "stopgradient":return new Identity(modelName, nodeName, inputs);
+ case "noop": return new NoOp(modelName, nodeName, inputs);
+
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
+ MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importSignatures(tfGraph, intermediateGraph);
+ importOperations(tfGraph, intermediateGraph, bundle);
+ verifyOutputTypes(tfGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
+ String signatureName = signatureEntry.getKey();
+ java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ String nodeName = input.getValue().getName();
+ intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
+ String outputName = output.getKey();
+ String nodeName = output.getValue().getName();
+ intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ }
+ }
+
+ private static void importOperations(MetaGraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
+ }
+ }
+ }
+
+ private static IntermediateOperation importOperation(String nodeName,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ if (intermediateGraph.alreadyImported(nodeName)) {
+ return intermediateGraph.get(nodeName);
+ }
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
+ IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
+ intermediateGraph.put(nodeName, operation);
+
+ List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
+ }
+
+ if (operation.isConstant()) {
+ operation.setConstantValueFunction(
+ type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
+ }
+
+ return operation;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static List<IntermediateOperation> importControlInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
+ }
+
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
+ for (NodeDef node : tfGraph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ Session.Runner fetched = bundle.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name +
+ ", but got " + importedTensors.size());
+ return importedTensors.get(0);
+ }
+
+ private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(node, type);
+ }
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
index 3f55e622fdf..d2d0acfc964 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
new file mode 100644
index 00000000000..67ad1edc312
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
@@ -0,0 +1,72 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.List;
+
+/**
+ * Converts and verifies TensorFlow tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(NodeDef node, OrderedTensorType type) {
+ TensorShapeProto shape = tensorFlowShape(node);
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+ "does not match Vespa shape");
+ }
+ for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
+ int vespaIndex = type.dimensionMap(tensorFlowIndex);
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+ "does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ }
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+ }
+ List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+ return shapeList.get(0); // support multiple outputs?
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ TensorShapeProto shape = tensorFlowShape(node);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
index 5cff8b03d40..1530754cc43 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
@@ -3,6 +3,6 @@
* ONNX integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
deleted file mode 100644
index fa1f929cc80..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
+++ /dev/null
@@ -1,326 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a ONNX model into a ranking expression and set of constants.
- *
- * @author lesters
- */
-public class OnnxImporter {
-
- private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
-
- public OnnxModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- public OnnxModel importModel(String modelName, String modelPath) {
- try (FileInputStream inputStream = new FileInputStream(modelPath)) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- return importModel(modelName, model);
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
- }
- }
-
- public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
- return importGraph(modelName, model.getGraph());
- }
-
- private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
- OnnxModel model = new OnnxModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importNodes(graph, model, index);
- verifyOutputTypes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index);
-
- reportWarnings(model, index);
-
- return model;
- }
-
- private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- importNode(valueInfo.getName(), graph, model, index);
- }
- }
-
- private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- if (index.alreadyImported(name)) {
- return index.get(name);
- }
- OnnxOperation operation;
- if (isArgumentTensor(name, graph)) {
- operation = new Argument(getArgumentTensor(name, graph));
- model.input(OnnxOperation.namePartOf(name), operation.vespaName());
- } else if (isConstantTensor(name, graph)) {
- operation = new Constant(model.name(), getConstantTensor(name, graph));
- } else {
- Onnx.NodeProto node = getNodeFromGraph(name, graph);
- List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
- operation = OperationMapper.get(node, inputs);
- if (isOutputNode(name, graph)) {
- model.output(OnnxOperation.namePartOf(name), operation.vespaName());
- }
- }
- index.put(operation.vespaName(), operation);
-
- return operation;
- }
-
- private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor == null;
- }
-
- private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
- }
-
- private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
- if (tensorProto.getName().equals(name)) {
- return tensorProto;
- }
- }
- return null;
- }
-
- private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
- return getOutputNode(name, graph) != null;
- }
-
- private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- String nodeName = OnnxOperation.namePartOf(valueInfo.getName());
- if (nodeName.equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
- Onnx.GraphProto graph,
- OnnxModel model,
- OperationIndex index) {
- return node.getInputList().stream()
- .map(nodeName -> importNode(nodeName, graph, model, index))
- .collect(Collectors.toList());
- }
-
- private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- OnnxOperation operation = index.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
- operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."))
- .verifyType(onnxNode.getType());
- }
- }
-
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(OnnxModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (String output : model.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- renamer.solve();
- for (String output : model.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
-
- private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model);
- if (!function.isPresent()) {
- model.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- model.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(operation, model);
- }
- importInputExpressions(operation, model);
- importRankingExpression(operation, model);
- importArgumentExpression(operation, model);
-
- return operation.function();
- }
-
- private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
- }
-
- private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Value value = operation.getConstantValue().orElseThrow(() ->
- new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
- "is constant but does not have a value."));
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
-
- Tensor tensor = value.asTensor();
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.function().isPresent()) {
- String name = operation.vespaName();
- if (!model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
-
- if (model.outputs().containsKey(name)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced from the output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.isInput()) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(OnnxModel model, OperationIndex index) {
- for (String output : model.outputs().values()) {
- reportWarnings(model, index.get(output));
- }
- }
-
- private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
- for (String warning : operation.warnings()) {
- model.importWarning(warning);
- }
- for (OnnxOperation input : operation.inputs()) {
- reportWarnings(model, input);
- }
- }
-
- private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
- }
- } else if (node.getName().equals(nodeName)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
- }
-
- private static class OperationIndex {
- private final Map<String, OnnxOperation> index = new HashMap<>();
- public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); }
- public OnnxOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<OnnxOperation> operations() { return index.values(); }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
deleted file mode 100644
index bd53afefc3f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * The result of importing an ONNX model into Vespa.
- *
- * @author bratseth
- * @author lesters
- */
-public class OnnxModel {
-
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
-
- private final String name;
-
- public OnnxModel(String name) {
- if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
- this.name = name;
- }
-
- /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
- public String name() { return name; }
-
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
- private final Map<String, String> skippedOutputs = new HashMap<>();
- private final List<String> importWarnings = new ArrayList<>();
-
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /**
- * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * to argument (Placeholder) name in the owner of this
- */
- public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
-
- /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */
- public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
-
- /** Returns an immutable list of the expression names of this */
- public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
-
- /**
- * Returns an immutable list of the outputs of this which could not be imported,
- * with a string detailing the reason for each
- */
- public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
-
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
- public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */
- public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
-
- /** Returns an immutable map of the arguments (inputs) of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
-
- /**
- * Returns an immutable map of the small constants of this.
- */
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
-
- /**
- * Returns an immutable map of the large constants of this.
- */
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
-
- /**
- * Returns an immutable map of the expressions of this - corresponding to ONNX nodes
- * which are not inputs or constants.
- */
- public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
-
- /** Returns an immutable map of macros that are part of this model */
- public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
-
- /** Returns an immutable map of the macros that must be provided by the environment running this model */
- public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
deleted file mode 100644
index 12090145d3a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import onnx.Onnx;
-
-import java.util.List;
-
-public class OperationMapper {
-
- public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- switch (node.getOpType().toLowerCase()) {
- case "add": return new Join(node, inputs, ScalarFunctions.add());
- case "matmul": return new MatMul(node, inputs);
- }
-
- OnnxOperation op = new NoOp(node, inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
- return op;
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
deleted file mode 100644
index a8d8d63daf4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class Argument extends OnnxOperation {
-
- private Onnx.ValueInfoProto valueInfo;
- private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
-
- public Argument(Onnx.ValueInfoProto valueInfoProto) {
- super(null, Collections.emptyList());
- valueInfo = valueInfoProto;
- standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType());
- }
-
- @Override
- public String vespaName() {
- return vespaName(valueInfo.getName());
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
- List<String> renameFrom = standardNamingType.dimensionNames();
- List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
- }
- return output;
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isInput() {
- return true;
- }
-
- @Override
- public boolean isConstant() {
- return false;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
deleted file mode 100644
index b1136a0ce0a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends OnnxOperation {
-
- public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, Collections.emptyList()); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
deleted file mode 100644
index 30f7b4f4711..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.Function;
-
-/**
- * Wraps an ONNX node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
- *
- * @author lesters
- */
-public abstract class OnnxOperation {
-
- protected final Onnx.NodeProto node; // can be null for onnx inputs and constants
- protected final List<OnnxOperation> inputs;
- protected final List<OnnxOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
-
- protected OrderedTensorType type;
- protected TensorFunction function;
- protected Value constantValue = null;
-
- OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- this.node = node;
- this.inputs = Collections.unmodifiableList(inputs);
- this.inputs.forEach(i -> i.outputs.add(this));
- }
-
- protected abstract OrderedTensorType lazyGetType();
- protected abstract TensorFunction lazyGetFunction();
-
- /** Returns the Vespa tensor type of this operation if it exists */
- public Optional<OrderedTensorType> type() {
- if (type == null) {
- type = lazyGetType();
- }
- return Optional.ofNullable(type);
- }
-
- /** Returns the Vespa tensor function implementing all operations from this node with inputs */
- public Optional<TensorFunction> function() {
- if (function == null) {
- if (isConstant()) {
- ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
- function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
- } else {
- function = lazyGetFunction();
- }
- }
- return Optional.ofNullable(function);
- }
-
- /** Return Onnx node */
- public Onnx.NodeProto node() { return node; }
-
- /** Return unmodifiable list of inputs */
- public List<OnnxOperation> inputs() { return inputs; }
-
- /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); }
-
- /** Add dimension name constraints for this operation */
- public void addDimensionNameConstraints(DimensionRenamer renamer) { }
-
- /** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
-
- /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
- public boolean isInput() { return false; }
-
- /** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); }
-
- /** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
-
- /** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(node.getName()); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
-
- /** Retrieve the list of warnings produced during its lifetime */
- public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Set an input warning */
- public void warning(String warning) { importWarnings.add(warning); }
-
- boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
- if (inputs.size() != expected) {
- throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
- }
- return inputs.stream().map(func).allMatch(Optional::isPresent);
- }
-
- boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::type);
- }
-
- boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::function);
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- public static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output index part. Indexes are used for nodes with
- * multiple outputs.
- */
- public static int indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
deleted file mode 100644
index e3c72830095..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ /dev/null
@@ -1,411 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorFlowImporter {
-
- private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
-
- /**
- * Imports a saved TensorFlow model from a directory.
- * The model should be saved as a .pbtxt or .pb file.
- * The name of the model is taken as the db/pbtxt file name (not including the file ending).
- *
- * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
- * @param modelDir the directory containing the TensorFlow model files to import
- */
- public TensorFlowModel importModel(String modelName, String modelDir) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
-
- return importModel(modelName, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- public TensorFlowModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- /** Imports a TensorFlow model */
- public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
- try {
- return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
- /**
- * Imports the TensorFlow graph by first importing the tensor types, then
- * finding a suitable set of dimensions names for each
- * placeholder/constant/variable, then importing the expressions.
- */
- private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
- TensorFlowModel model = new TensorFlowModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importSignatures(graph, model);
- importNodes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index, bundle);
-
- reportWarnings(model, index);
- logVariableTypes(index);
-
- return model;
- }
-
- private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
- for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- TensorFlowModel.Signature signature = model.signature(signatureName);
-
- Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- signature.input(inputName, namePartOf(input.getValue().getName()));
- }
-
- Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- signature.output(outputName, namePartOf(output.getValue().getName()));
- }
- }
- }
-
- private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String inputName : signature.inputs().values()) {
- if (inputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- if (outputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- importNode(model.name(), outputName, graph.getGraphDef(), index);
- }
- }
- }
-
- private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
- if (index.alreadyImported(nodeName)) {
- return index.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
- List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
- TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
- index.put(nodeName, operation);
-
- List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- return operation;
- }
-
- private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- }
- renamer.solve();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
- }
-
- private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
- if (!function.isPresent()) {
- signature.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(model, operation, bundle);
- }
-
- importInputExpressions(operation, model, bundle);
- importRankingExpression(model, operation);
- importInputExpression(model, operation);
- importMacroExpression(model, operation);
-
- return operation.function();
- }
-
- private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
- SavedModelBundle bundle) {
- operation.inputs().forEach(input -> importExpression(input, model, bundle));
- }
-
- private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.macro().isPresent()) {
- TensorFunction function = operation.macro().get();
- try {
- model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
-
- private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
- SavedModelBundle bundle) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Tensor tensor;
- if (operation.getConstantValue().isPresent()) {
- Value value = operation.getConstantValue().get();
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
- tensor = value.asTensor();
- } else {
- // Here we use the type from the operation, which will have correct dimension names after name resolving
- tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
- operation.type().get());
- operation.setConstantValue(new TensorValue(tensor));
- }
-
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.function().isPresent()) {
- String name = operation.node().getName();
- if (!model.expressions().containsKey(operation.node().getName())) {
- TensorFunction function = operation.function().get();
-
- // Make sure output adheres to standard naming convention
- if (isSignatureOutput(model, operation)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced in a signature output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.isInput() && isSignatureInput(model, operation)) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
- model.argument(operation.node().getName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- reportWarnings(index.get(output), signature);
- }
- }
- }
-
- /**
- * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
- * This allows users to learn the exact types (including dimension order after renaming) of the Variables
- * such that these can be converted and fed to a parent document independently of the rest of the model
- * for fast model weight updates.
- */
- private static void logVariableTypes(OperationIndex index) {
- for (TensorFlowOperation operation : index.operations()) {
- if ( ! (operation instanceof Variable)) continue;
- if ( ! operation.type().isPresent()) continue; // will not happen
-
- log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
- " of type " + operation.type().get());
- }
- }
-
- private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
- for (String warning : operation.warnings()) {
- signature.importWarning(warning);
- }
- for (TensorFlowOperation input : operation.inputs()) {
- reportWarnings(input, signature);
- }
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
- for (NodeDef node : graph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- private static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output port part. Indexes are used for nodes with
- * multiple outputs.
- */
- private static int portPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
- private static class OperationIndex {
-
- private final Map<String, TensorFlowOperation> index = new HashMap<>();
- public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
- public TensorFlowOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<TensorFlowOperation> operations() { return index.values(); }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
deleted file mode 100644
index c1665d066a4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
+++ /dev/null
@@ -1,210 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-
-/**
- * A constraint satisfier to find suitable dimension names to reduce the
- * amount of necessary renaming during evaluation of an imported model.
- *
- * @author lesters
- */
-public class DimensionRenamer {
-
- private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
-
- private int iterations = 0;
-
- public DimensionRenamer() {
- this("d");
- }
-
- public DimensionRenamer(String dimensionPrefix) {
- this.dimensionPrefix = dimensionPrefix;
- }
-
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
-
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
- Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
- }
-
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
- }
-
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- public void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
-
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
- }
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
- }
-
- public void solve() {
- solve(100000);
- }
-
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
- }
- }
- }
-
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
- }
- return true;
- }
-
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
- }
- }
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
- }
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
- }
-
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
- }
-
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
- }
-
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
- }
-
- private static class Arc {
-
- private final String from;
- private final String to;
- private final TensorFlowOperation operation;
-
- Arc(String from, String to, TensorFlowOperation operation) {
- this.from = from;
- this.to = to;
- this.operation = operation;
- }
-
- Arc opposite() {
- return new Arc(to, from, operation);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(from, to);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
- return false;
- }
- Arc other = (Arc) obj;
- return Objects.equals(from, other.from) && Objects.equals(to, other.to);
- }
-
- @Override
- public String toString() {
- return String.format("%s -> %s", from, to);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
deleted file mode 100644
index b665413a6b2..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-/**
- * Maps from TensorFlow operations to Vespa operations.
- *
- * @author bratseth
- * @author lesters
- */
-public class OperationMapper {
-
- public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, node, inputs, port);
- case "const": return new Const(modelName, node, inputs, port);
- case "expanddims": return new ExpandDims(modelName, node, inputs, port);
- case "identity": return new Identity(modelName, node, inputs, port);
- case "placeholder": return new Placeholder(modelName, node, inputs, port);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port);
- case "reshape": return new Reshape(modelName, node, inputs, port);
- case "shape": return new Shape(modelName, node, inputs, port);
- case "squeeze": return new Squeeze(modelName, node, inputs, port);
-
- // control flow
- case "merge": return new Merge(modelName, node, inputs, port);
- case "switch": return new Switch(modelName, node, inputs, port);
-
- // math ops
- case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos());
- case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor());
- case "matmul": return new Matmul(modelName, node, inputs, port);
- case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max());
- case "mean": return new Mean(modelName, node, inputs, port);
- case "reducemean": return new Mean(modelName, node, inputs, port);
- case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, node, inputs, port);
- case "where3": return new Select(modelName, node, inputs, port);
- case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu());
- case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu());
- case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu());
-
- // state ops
- case "variable": return new Variable(modelName, node, inputs, port);
- case "variablev2": return new Variable(modelName, node, inputs, port);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, node, inputs, port);
- case "noop": return new NoOp(modelName, node, inputs, port);
- }
-
- TensorFlowOperation op = new NoOp(modelName, node, inputs, port);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
-}
-
-
-
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
deleted file mode 100644
index 03a65333192..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorTypeParser;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. TensorFlow tensors have an explicit ordering of their dimensions.
- * During import, we need to track the Vespa dimension that matches the
- * corresponding TensorFlow dimension as the ordering can change after
- * dimension renaming. That is the purpose of this class.
- *
- * @author lesters
- */
-public class OrderedTensorType {
-
- private final TensorType type;
- private final List<TensorType.Dimension> dimensions;
-
- private final long[] innerSizesTensorFlow;
- private final long[] innerSizesVespa;
- private final int[] dimensionMap;
-
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
- this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesTensorFlow = new long[dimensions.size()];
- this.innerSizesVespa = new long[dimensions.size()];
- this.dimensionMap = createDimensionMap();
- }
-
- public TensorType type() {
- return this.type;
- }
-
- public int rank() { return dimensions.size(); }
-
- public List<TensorType.Dimension> dimensions() {
- return dimensions;
- }
-
- public List<String> dimensionNames() {
- return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
- }
-
- private int[] createDimensionMap() {
- int numDimensions = dimensions.size();
- if (numDimensions == 0) {
- return null;
- }
- innerSizesTensorFlow[numDimensions - 1] = 1;
- innerSizesVespa[numDimensions - 1] = 1;
- for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
- innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
- }
- int[] mapping = new int[numDimensions];
- for (int i = 0; i < numDimensions; ++i) {
- TensorType.Dimension dim1 = dimensions().get(i);
- for (int j = 0; j < numDimensions; ++j) {
- TensorType.Dimension dim2 = type.dimensions().get(j);
- if (dim1.equals(dim2)) {
- mapping[i] = j;
- break;
- }
- }
- }
- return mapping;
- }
-
- /**
- * When dimension ordering between Vespa and TensorFlow differs, i.e.
- * after dimension renaming, use the dimension map to read in values
- * so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from TensorFlow.
- */
- public int toDirectIndex(int index) {
- if (dimensions.size() == 0) {
- return 0;
- }
- if (dimensionMap == null) {
- throw new IllegalArgumentException("Dimension map is not available");
- }
- int directIndex = 0;
- long rest = index;
- for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesTensorFlow[i];
- directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesTensorFlow[i];
- }
- return directIndex;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof OrderedTensorType)) {
- return false;
- }
- OrderedTensorType other = (OrderedTensorType) obj;
- if (dimensions.size() != dimensions.size()) {
- return false;
- }
- List<TensorType.Dimension> thisDimensions = this.dimensions();
- List<TensorType.Dimension> otherDimensions = other.dimensions();
- for (int i = 0; i < thisDimensions.size(); ++i) {
- if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
- return false;
- }
- }
- return true;
- }
-
- public void verifyType(NodeDef node) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
- int vespaIndex = dimensionMap[tensorFlowIndex];
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- public OrderedTensorType rename(DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
- for (TensorType.Dimension dimension : dimensions) {
- String oldName = dimension.name();
- Optional<String> newName = renamer.dimensionNameOf(oldName);
- if (!newName.isPresent())
- return this; // presumably, already renamed
- TensorType.Dimension.Type dimensionType = dimension.type();
- if (dimensionType == TensorType.Dimension.Type.indexedBound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
- } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
- } else if (dimensionType == TensorType.Dimension.Type.mapped) {
- renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
- }
- }
- return new OrderedTensorType(renamedDimensions);
- }
-
- /**
- * Returns a string representation of this: A standard tensor type string where dimensions
- * are listed in the order of this rather than in the natural order of their names.
- */
- @Override
- public String toString() {
- return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
-
- /**
- * Creates an instance from the string representation of this: A standard tensor type string
- * where dimensions are listed in the order of this rather than the natural order of their names.
- */
- public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- Builder builder = new Builder(node);
- TensorShapeProto shape = tensorFlowShape(node);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static class Builder {
-
- private final TensorShapeProto shape;
- private final List<TensorType.Dimension> dimensions;
-
- public Builder(NodeDef node) {
- this.shape = tensorFlowShape(node);
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
- public Builder add(TensorType.Dimension vespaDimension) {
- int index = dimensions.size();
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
- long size = tensorFlowDimension.getSize();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- }
- this.dimensions.add(vespaDimension);
- return this;
- }
-
- public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
deleted file mode 100644
index 6cbfe0dfb05..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-
-public class Join extends TensorFlowOperation {
-
- private final DoubleBinaryOperator operator;
-
- public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
- super(modelName, node, inputs, port);
- this.operator = operator;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
-
- // Well now we have potentially entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // In broadcasting, the size of each dimension is compared element-wise,
- // starting with the trailing dimensions and working forward. A special
- // case occurs when the size of one dimension is 1, while the other is not.
- // Then the dimension with size 1 is "stretched" to be of compatible size.
- //
- // An example:
- //
- // Tensor A: d0[5], d1[1], d2[3], d3[1]
- // Tensor B: d1[4], d2[1], d3[2]
- //
- // In TensorFlow and using the above rules of broadcasting, the resulting
- // type is:
- // d0[5], d1[4], d2[3], d2[2]
- //
- // However, in Vespa's tensor logic, the join of the two above tensors would
- // result in a tensor of type:
- // d0[5], d1[1], d2[1], d3[1]
- //
- // By reducing the dimensions of size 1 in each tensor before joining,
- // we get equal results as in TensorFlow.
-
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < a.rank(); ++i) {
- TensorType.Dimension aDim = a.dimensions().get(i);
- long size = aDim.size().orElse(-1L);
-
- if (i - sizeDifference >= 0) {
- TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
- size = Math.max(size, bDim.size().orElse(-1L));
- }
-
- if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
- builder.add(TensorType.Dimension.indexed(aDim.name(), size));
- } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
- builder.add(TensorType.Dimension.indexed(aDim.name()));
- } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
- builder.add(TensorType.Dimension.mapped(aDim.name()));
- }
- }
- return builder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
-
- TensorFlowOperation a = largestInput();
- TensorFlowOperation b = smallestInput();
-
- List<String> aDimensionsToReduce = new ArrayList<>();
- List<String> bDimensionsToReduce = new ArrayList<>();
- int sizeDifference = a.type().get().rank() - b.type().get().rank();
- for (int i = 0; i < b.type().get().rank(); ++i) {
- TensorType.Dimension bDim = b.type().get().dimensions().get(i);
- TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
- long bSize = bDim.size().orElse(-1L);
- long aSize = aDim.size().orElse(-1L);
- if (bSize == 1L && aSize != 1L) {
- bDimensionsToReduce.add(bDim.name());
- }
- if (aSize == 1L && bSize != 1L) {
- aDimensionsToReduce.add(bDim.name());
- }
- }
-
- TensorFunction aReducedFunction = a.function().get();
- if (aDimensionsToReduce.size() > 0) {
- aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
- }
- TensorFunction bReducedFunction = b.function().get();
- if (bDimensionsToReduce.size() > 0) {
- bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
- }
-
- return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < b.rank(); ++i) {
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
- }
- }
-
- private TensorFlowOperation largestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
- private TensorFlowOperation smallestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
deleted file mode 100644
index b2b9530a161..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Optional;
-
-public class Matmul extends TensorFlowOperation {
-
- public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
-
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
-
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
deleted file mode 100644
index d558ec89e87..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends TensorFlowOperation {
-
- public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, Collections.emptyList(), port); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
deleted file mode 100644
index b18a8a9b212..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-public class Variable extends TensorFlowOperation {
-
- public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName() + "_" + super.vespaName();
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null; // will be added by function() since this is constant.
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
deleted file mode 100644
index 9e53990a9d6..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Tensorflow integration
- */
-@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java
index bb2110a0f5f..51a1b09b9fa 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java
@@ -16,7 +16,7 @@ import java.util.LinkedList;
import java.util.List;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public final class Benchmark {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java
index 280ffc6278b..760e056327c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java
@@ -147,11 +147,6 @@ public class StreamEvaluationBenchmark {
new StreamEvaluationBenchmark().run();
}
- private void assertEqualish(double a,double b) {
- if (Math.abs(a-b) >= Math.abs((a+b)/100000000) )
- throw new RuntimeException("Expected value " + a + " but optimized evaluation produced " + b);
- }
-
private void bindStreamingFeatures(Map<String, Double> featureItem, Context context) {
for (Map.Entry<String, Double> feature : featureItem.entrySet())
context.put(feature.getKey(), feature.getValue());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 0f5eec93feb..bf9684082f4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -15,7 +15,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
index 74b0d11f1d6..c8c7ec798bb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 50a467ec581..a63c7346335 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,7 @@ public class DropoutImportTestCase {
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
model.get().requiredMacros().get("X"));
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
index 9f919c452d6..bd7644be23b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase {
// Check signatures
assertEquals(1, model.get().signatures().size());
- TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
+ ImportedModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index 4b68cd40a08..a7926cd2e02 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,11 +1,9 @@
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -24,7 +22,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -48,7 +46,7 @@ public class OnnxMnistSoftmaxImportTestCase {
model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.outputExpression("add");
+ RankingExpression output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
@@ -68,13 +66,12 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
- TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
+ ImportedModel model = new TensorFlowImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel("test", path);
+ ImportedModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
@@ -83,14 +80,7 @@ public class OnnxMnistSoftmaxImportTestCase {
return expression.evaluate(context).asTensor();
}
- private Context contextFrom(TensorFlowModel result) {
- MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- return context;
- }
-
- private Context contextFrom(OnnxModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
index beec2ab1ead..b2443082ab1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 7ca16939477..723c5f27914 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -1,11 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class TestableTensorFlowModel {
private SavedModelBundle tensorFlowModel;
- private TensorFlowModel model;
+ private ImportedModel model;
// Sizes of the input vector
private final int d0Size = 1;
@@ -39,7 +39,7 @@ public class TestableTensorFlowModel {
model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
- public TensorFlowModel get() { return model; }
+ public ImportedModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
@@ -66,7 +66,7 @@ public class TestableTensorFlowModel {
return TensorConverter.toVespaTensor(results.get(0));
}
- private Context contextFrom(TensorFlowModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
@@ -81,7 +81,7 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
+ private void evaluateMacro(Context context, ImportedModel model, String macroName) {
if (!context.names().contains(macroName)) {
RankingExpression e = model.macros().get(macroName);
evaluateMacroDependencies(context, model, e.getRoot());
@@ -89,7 +89,7 @@ public class TestableTensorFlowModel {
}
}
- private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
+ private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.macros().containsKey(name)) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
index 051c2c60c95..f94098e6255 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
@@ -1,4 +1,4 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import org.junit.Test;
diff --git a/searchsummary/CMakeLists.txt b/searchsummary/CMakeLists.txt
index 5f6e8881f13..4df636e0219 100644
--- a/searchsummary/CMakeLists.txt
+++ b/searchsummary/CMakeLists.txt
@@ -24,6 +24,7 @@ vespa_define_module(
TESTS
src/tests/docsumformat
src/tests/docsummary
+ src/tests/docsummary/attribute_combiner
src/tests/docsummary/slime_summary
src/tests/extractkeywords
)
diff --git a/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt b/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt
new file mode 100644
index 00000000000..df323b9c982
--- /dev/null
+++ b/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchsummary_attribute_combiner_test_app TEST
+ SOURCES
+ attribute_combiner_test.cpp
+ DEPENDS
+ searchsummary
+)
+vespa_add_test(NAME searchsummary_attribute_combiner_test_app COMMAND searchsummary_attribute_combiner_test_app)
diff --git a/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp b/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp
new file mode 100644
index 00000000000..97fafd0a446
--- /dev/null
+++ b/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp
@@ -0,0 +1,217 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchcommon/common/undefinedvalues.h>
+#include <vespa/searchlib/attribute/attributefactory.h>
+#include <vespa/searchlib/attribute/attributemanager.h>
+#include <vespa/searchlib/attribute/attributevector.h>
+#include <vespa/searchlib/attribute/attributevector.hpp>
+#include <vespa/searchlib/attribute/floatbase.h>
+#include <vespa/searchlib/attribute/integerbase.h>
+#include <vespa/searchlib/attribute/stringbase.h>
+#include <vespa/searchlib/util/slime_output_raw_buf_adapter.h>
+#include <vespa/searchsummary/docsummary/docsumstate.h>
+#include <vespa/searchsummary/docsummary/docsum_field_writer_state.h>
+#include <vespa/searchsummary/docsummary/attribute_combiner_dfw.h>
+#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/testkit/testapp.h>
+
+#include <vespa/log/log.h>
+LOG_SETUP("attribute_combiner_test");
+
+using search::AttributeFactory;
+using search::AttributeManager;
+using search::AttributeVector;
+using search::IntegerAttribute;
+using search::FloatingPointAttribute;
+using search::StringAttribute;
+using search::attribute::BasicType;
+using search::attribute::CollectionType;
+using search::attribute::Config;
+using search::attribute::IAttributeVector;
+using search::attribute::getUndefined;
+using search::docsummary::AttributeCombinerDFW;
+using search::docsummary::GetDocsumsState;
+using search::docsummary::GetDocsumsStateCallback;
+using search::docsummary::IDocsumEnvironment;
+using search::docsummary::IDocsumFieldWriter;
+
+namespace {
+
+vespalib::string
+toCompactJsonString(const vespalib::Slime &slime)
+{
+ vespalib::SimpleBuffer buf;
+ vespalib::slime::JsonFormat::encode(slime, buf, true);
+ return buf.get().make_string();
+}
+
+struct FieldBlock {
+ vespalib::string input;
+ vespalib::Slime slime;
+ search::RawBuf binary;
+ vespalib::string json;
+
+ explicit FieldBlock(const vespalib::string &jsonInput)
+ : input(jsonInput), slime(), binary(1024), json()
+ {
+ size_t used = vespalib::slime::JsonFormat::decode(jsonInput, slime);
+ EXPECT_TRUE(used > 0);
+ json = toCompactJsonString(slime);
+ search::SlimeOutputRawBufAdapter adapter(binary);
+ vespalib::slime::BinaryFormat::encode(slime, adapter);
+ }
+ const char *data() const { return binary.GetDrainPos(); }
+ size_t dataLen() const { return binary.GetUsedLen(); }
+};
+
+struct AttributeManagerFixture
+{
+ AttributeManager mgr;
+
+ AttributeManagerFixture();
+
+ ~AttributeManagerFixture();
+
+ template <typename AttributeType, typename ValueType>
+ void
+ buildAttribute(const vespalib::string &name,
+ BasicType type,
+ std::vector<std::vector<ValueType>> values);
+
+ void
+ buildStringAttribute(const vespalib::string &name,
+ std::vector<std::vector<vespalib::string>> values);
+ void
+ buildFloatAttribute(const vespalib::string &name,
+ std::vector<std::vector<double>> values);
+
+ void
+ buildIntegerAttribute(const vespalib::string &name,
+ BasicType type,
+ std::vector<std::vector<IAttributeVector::largeint_t>> values);
+};
+
+AttributeManagerFixture::AttributeManagerFixture()
+ : mgr()
+{
+ buildStringAttribute("array.name", {{"n1.1", "n1.2"}, {"n2"}, {"n3.1", "n3.2"}, {"", "n4.2"}});
+ buildIntegerAttribute("array.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {30}, { getUndefined<int8_t>(), 41}});
+ buildFloatAttribute("array.fval", {{ 110.0}, { 120.0, 121.0 }, { 130.0, 131.0}, { getUndefined<double>(), 141.0 }});
+}
+
+AttributeManagerFixture::~AttributeManagerFixture() = default;
+
+template <typename AttributeType, typename ValueType>
+void
+AttributeManagerFixture::buildAttribute(const vespalib::string &name,
+ BasicType type,
+ std::vector<std::vector<ValueType>> values)
+{
+ Config cfg(type, CollectionType::Type::ARRAY);
+ auto attrBase = AttributeFactory::createAttribute(name, cfg);
+ EXPECT_TRUE(attrBase);
+ auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase);
+ EXPECT_TRUE(attr);
+ attr->addReservedDoc();
+ for (const auto &docValues : values) {
+ uint32_t docId = 0;
+ EXPECT_TRUE(attr->addDoc(docId));
+ EXPECT_NOT_EQUAL(0u, docId);
+ for (const auto &value : docValues) {
+ attr->append(docId, value, 1);
+ }
+ attr->commit();
+ }
+ EXPECT_TRUE(mgr.add(attr));
+}
+
+void
+AttributeManagerFixture::buildStringAttribute(const vespalib::string &name,
+ std::vector<std::vector<vespalib::string>> values)
+{
+ buildAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values));
+}
+
+void
+AttributeManagerFixture::buildFloatAttribute(const vespalib::string &name,
+ std::vector<std::vector<double>> values)
+{
+ buildAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values));
+}
+
+void
+AttributeManagerFixture::buildIntegerAttribute(const vespalib::string &name,
+ BasicType type,
+ std::vector<std::vector<IAttributeVector::largeint_t>> values)
+{
+ buildAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values));
+}
+
+
+class DummyStateCallback : public GetDocsumsStateCallback
+{
+public:
+ void FillSummaryFeatures(GetDocsumsState *, IDocsumEnvironment *) override { }
+ void FillRankFeatures(GetDocsumsState *, IDocsumEnvironment *) override { }
+ void ParseLocation(GetDocsumsState *) override { }
+ ~DummyStateCallback() override { }
+};
+
+
+struct Fixture
+{
+ AttributeManagerFixture attrs;
+ std::unique_ptr<IDocsumFieldWriter> writer;
+ DummyStateCallback stateCallback;
+ GetDocsumsState state;
+
+ Fixture();
+ ~Fixture();
+ void assertWritten(const vespalib::string &exp, uint32_t docId);
+};
+
+Fixture::Fixture()
+ : attrs(),
+ writer(AttributeCombinerDFW::create("array", attrs.mgr)),
+ stateCallback(),
+ state(stateCallback)
+{
+ EXPECT_TRUE(writer->setFieldWriterStateIndex(0));
+ state._attrCtx = attrs.mgr.createContext();
+ state._fieldWriterStates.resize(1);
+}
+
+Fixture::~Fixture()
+{
+}
+
+void
+Fixture::assertWritten(const vespalib::string &expectedJson, uint32_t docId)
+{
+ vespalib::Slime target;
+ vespalib::slime::SlimeInserter inserter(target);
+ writer->insertField(docId, nullptr, &state, search::docsummary::RES_JSONSTRING, inserter);
+ search::RawBuf binary(1024);
+ vespalib::string json = toCompactJsonString(target);
+ search::SlimeOutputRawBufAdapter adapter(binary);
+ vespalib::slime::BinaryFormat::encode(target, adapter);
+ FieldBlock block(expectedJson);
+ if (!EXPECT_EQUAL(block.dataLen(), binary.GetUsedLen()) ||
+ !EXPECT_EQUAL(0, memcmp(block.data(), binary.GetDrainPos(), block.dataLen()))) {
+ LOG(error, "Expected '%s'", expectedJson.c_str());
+ LOG(error, "Expected normalized '%s'", block.json.c_str());
+ LOG(error, "Got '%s'", json.c_str());
+ }
+}
+
+TEST_F("require that attributes combiner dfw generates correct slime output for array of struct", Fixture())
+{
+ f.assertWritten("[ { fval: 110.0, name: \"n1.1\", val: 10}, { name: \"n1.2\", val: 11}]", 1);
+ f.assertWritten("[ { fval: 120.0, name: \"n2\", val: 20}, { fval: 121.0, val: 21 }]", 2);
+ f.assertWritten("[ { fval: 130.0, name: \"n3.1\", val: 30}, { fval: 131.0, name: \"n3.2\"} ]", 3);
+ f.assertWritten("[ { }, { fval: 141.0, name: \"n4.2\", val: 41} ]", 4);
+}
+
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt b/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt
index 9009f0bcbc7..ce54e7b0ea7 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt
+++ b/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt
@@ -1,6 +1,9 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(searchsummary_docsummary OBJECT
SOURCES
+ array_attribute_combiner_dfw.cpp
+ attribute_combiner_dfw.cpp
+ attribute_field_writer.cpp
resultclass.cpp
resultconfig.cpp
resultpacker.cpp
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp
new file mode 100644
index 00000000000..84e329f159d
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp
@@ -0,0 +1,89 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "array_attribute_combiner_dfw.h"
+#include "docsum_field_writer_state.h"
+#include "attribute_field_writer.h"
+#include <vespa/searchcommon/attribute/iattributecontext.h>
+#include <vespa/searchcommon/attribute/iattributevector.h>
+#include <vespa/vespalib/data/slime/cursor.h>
+
+using search::attribute::IAttributeContext;
+using search::attribute::IAttributeVector;
+using vespalib::slime::Cursor;
+
+namespace search::docsummary {
+
+namespace {
+
+class ArrayAttributeFieldWriterState : public DocsumFieldWriterState
+{
+ std::vector<std::unique_ptr<AttributeFieldWriter>> _writers;
+
+public:
+ ArrayAttributeFieldWriterState(const std::vector<vespalib::string> &fieldNames,
+ const std::vector<vespalib::string> &attributeNames,
+ IAttributeContext &context);
+ ~ArrayAttributeFieldWriterState() override;
+ void insertField(uint32_t docId, vespalib::slime::Inserter &target) override;
+};
+
+ArrayAttributeFieldWriterState::ArrayAttributeFieldWriterState(const std::vector<vespalib::string> &fieldNames,
+ const std::vector<vespalib::string> &attributeNames,
+ IAttributeContext &context)
+ : DocsumFieldWriterState()
+{
+ size_t fields = fieldNames.size();
+ _writers.reserve(fields);
+ for (uint32_t field = 0; field < fields; ++field) {
+ const IAttributeVector *attr = context.getAttribute(attributeNames[field]);
+ if (attr != nullptr) {
+ _writers.emplace_back(AttributeFieldWriter::create(fieldNames[field], *attr));
+ }
+ }
+}
+
+ArrayAttributeFieldWriterState::~ArrayAttributeFieldWriterState() = default;
+
+void
+ArrayAttributeFieldWriterState::insertField(uint32_t docId, vespalib::slime::Inserter &target)
+{
+ uint32_t elems = 0;
+ for (auto &writer : _writers) {
+ writer->fetch(docId);
+ if (elems < writer->size()) {
+ elems = writer->size();
+ }
+ }
+ Cursor &arr = target.insertArray();
+ for (uint32_t idx = 0; idx < elems; ++idx) {
+ Cursor &obj = arr.addObject();
+ for (auto &writer : _writers) {
+ writer->print(idx, obj);
+ }
+ }
+}
+
+}
+
+ArrayAttributeCombinerDFW::ArrayAttributeCombinerDFW(const vespalib::string &fieldName,
+ const std::vector<vespalib::string> &fields)
+ : AttributeCombinerDFW(fieldName),
+ _fields(fields),
+ _attributeNames()
+{
+ _attributeNames.reserve(_fields.size());
+ vespalib::string prefix = fieldName + ".";
+ for (const auto &field : _fields) {
+ _attributeNames.emplace_back(prefix + field);
+ }
+}
+
+ArrayAttributeCombinerDFW::~ArrayAttributeCombinerDFW() = default;
+
+std::unique_ptr<DocsumFieldWriterState>
+ArrayAttributeCombinerDFW::allocFieldWriterState(IAttributeContext &context)
+{
+ return std::make_unique<ArrayAttributeFieldWriterState>(_fields, _attributeNames, context);
+}
+
+}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h
new file mode 100644
index 00000000000..c02d2bd5da6
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h
@@ -0,0 +1,29 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "attribute_combiner_dfw.h"
+
+namespace search::attribute { class IAttributeContext; }
+
+namespace search::docsummary {
+
+class DocsumFieldWriterState;
+
+/*
+ * This class reads values from multiple struct field attributes and
+ * inserts them as an array of struct.
+ */
+class ArrayAttributeCombinerDFW : public AttributeCombinerDFW
+{
+ std::vector<vespalib::string> _fields;
+ std::vector<vespalib::string> _attributeNames;
+
+ std::unique_ptr<DocsumFieldWriterState> allocFieldWriterState(search::attribute::IAttributeContext &context) override;
+public:
+ ArrayAttributeCombinerDFW(const vespalib::string &fieldName,
+ const std::vector<vespalib::string> &fields);
+ ~ArrayAttributeCombinerDFW() override;
+};
+
+}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp
new file mode 100644
index 00000000000..b532cfb273a
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp
@@ -0,0 +1,141 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "attribute_combiner_dfw.h"
+#include "array_attribute_combiner_dfw.h"
+#include "docsum_field_writer_state.h"
+#include "docsumstate.h"
+#include <vespa/searchlib/attribute/attributeguard.h>
+#include <vespa/searchlib/attribute/attributevector.h>
+#include <vespa/searchlib/attribute/iattributemanager.h>
+#include <algorithm>
+
+#include <vespa/log/log.h>
+LOG_SETUP(".searchsummary.docsummary.attribute_combiner_dfw");
+
+using search::AttributeGuard;
+using search::AttributeVector;
+using search::attribute::CollectionType;
+
+namespace search::docsummary {
+
+namespace {
+
+class StructFields
+{
+ std::vector<vespalib::string> _mapFields;
+ std::vector<vespalib::string> _arrayFields;
+ bool _hasMapKey;
+ bool _error;
+
+public:
+ StructFields(const vespalib::string &fieldName, const IAttributeManager &attrMgr);
+ ~StructFields();
+ const std::vector<vespalib::string> &getMapFields() const { return _mapFields; }
+ const std::vector<vespalib::string> &getArrayFields() const { return _arrayFields; }
+ bool hasMapKey() const { return _hasMapKey; }
+ bool getError() const { return _error; }
+};
+
+
+StructFields::StructFields(const vespalib::string &fieldName, const IAttributeManager &attrMgr)
+ : _mapFields(),
+ _arrayFields(),
+ _hasMapKey(false),
+ _error(false)
+{
+ // Note: Doesn't handle imported attributes
+ std::vector<AttributeGuard> attrs;
+ attrMgr.getAttributeList(attrs);
+ vespalib::string prefix = fieldName + ".";
+ vespalib::string keyName = prefix + "key";
+ vespalib::string valuePrefix = prefix + "value.";
+ for (const auto &guard : attrs) {
+ vespalib::string name = guard->getName();
+ if (name.substr(0, prefix.size()) != prefix) {
+ continue;
+ }
+ auto collType = guard->getCollectionType();
+ if (collType != CollectionType::Type::ARRAY) {
+ LOG(warning, "Attribute %s is not an array attribute", name.c_str());
+ _error = true;
+ break;
+ }
+ if (name.substr(0, valuePrefix.size()) == valuePrefix) {
+ _mapFields.emplace_back(name.substr(valuePrefix.size()));
+ } else {
+ _arrayFields.emplace_back(name.substr(prefix.size()));
+ if (name == keyName) {
+ _hasMapKey = true;
+ }
+ }
+ }
+ if (!_error) {
+ std::sort(_arrayFields.begin(), _arrayFields.end());
+ std::sort(_mapFields.begin(), _mapFields.end());
+ if (!_mapFields.empty()) {
+ if (!_hasMapKey) {
+ LOG(warning, "Missing key attribute '%s', have value attributes for map", keyName.c_str());
+ _error = true;
+ } else if (_arrayFields.size() != 1u) {
+ LOG(warning, "Could not determine if field '%s' is array or map of struct", fieldName.c_str());
+ _error = true;
+ }
+ }
+ }
+}
+
+StructFields::~StructFields() = default;
+
+}
+
+AttributeCombinerDFW::AttributeCombinerDFW(const vespalib::string &fieldName)
+ : IDocsumFieldWriter(),
+ _stateIndex(0),
+ _fieldName(fieldName)
+{
+}
+
+AttributeCombinerDFW::~AttributeCombinerDFW() = default;
+
+bool
+AttributeCombinerDFW::IsGenerated() const
+{
+ return true;
+}
+
+bool
+AttributeCombinerDFW::setFieldWriterStateIndex(uint32_t fieldWriterStateIndex)
+{
+ _stateIndex = fieldWriterStateIndex;
+ return true;
+}
+
+std::unique_ptr<IDocsumFieldWriter>
+AttributeCombinerDFW::create(const vespalib::string &fieldName, IAttributeManager &attrMgr)
+{
+ StructFields structFields(fieldName, attrMgr);
+ if (structFields.getError()) {
+ return std::unique_ptr<IDocsumFieldWriter>();
+ } else if (!structFields.getMapFields().empty()) {
+ LOG(warning, "map of struct is not yet supported for field '%s'", fieldName.c_str());
+ return std::unique_ptr<IDocsumFieldWriter>();
+ }
+ return std::make_unique<ArrayAttributeCombinerDFW>(fieldName, structFields.getArrayFields());
+}
+
+void
+AttributeCombinerDFW::insertField(uint32_t docid,
+ GeneralResult *,
+ GetDocsumsState *state,
+ ResType,
+ vespalib::slime::Inserter &target)
+{
+ auto &fieldWriterState = state->_fieldWriterStates[_stateIndex];
+ if (!fieldWriterState) {
+ fieldWriterState = allocFieldWriterState(*state->_attrCtx);
+ }
+ fieldWriterState->insertField(docid, target);
+}
+
+}
+
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h
new file mode 100644
index 00000000000..ef54522a923
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h
@@ -0,0 +1,36 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "docsumfieldwriter.h"
+
+namespace search::attribute { class IAttributeContext; }
+
+namespace search::docsummary {
+
+class DocsumFieldWriterState;
+class DynamicDocsumWriter;
+
+/*
+ * This class reads values from multiple struct field attributes and
+ * inserts them as an array of struct or a map of struct.
+ */
+class AttributeCombinerDFW : public IDocsumFieldWriter
+{
+protected:
+ uint32_t _stateIndex;
+ vespalib::string _fieldName;
+ AttributeCombinerDFW(const vespalib::string &fieldName);
+protected:
+ virtual std::unique_ptr<DocsumFieldWriterState> allocFieldWriterState(search::attribute::IAttributeContext &context) = 0;
+public:
+ ~AttributeCombinerDFW() override;
+ bool IsGenerated() const override;
+ bool setFieldWriterStateIndex(uint32_t fieldWriterStateIndex) override;
+ static std::unique_ptr<IDocsumFieldWriter> create(const vespalib::string &fieldName, IAttributeManager &attrMgr);
+ void insertField(uint32_t docid, GeneralResult *gres, GetDocsumsState *state,
+ ResType type, vespalib::slime::Inserter &target) override;
+};
+
+}
+
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp
new file mode 100644
index 00000000000..2eebe7137dc
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp
@@ -0,0 +1,172 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "attribute_field_writer.h"
+#include <vespa/searchcommon/attribute/attributecontent.h>
+#include <vespa/searchcommon/common/undefinedvalues.h>
+#include <vespa/vespalib/data/slime/cursor.h>
+#include <cassert>
+
+using search::attribute::BasicType;
+using search::attribute::IAttributeVector;
+using search::attribute::getUndefined;
+using vespalib::slime::Cursor;
+
+namespace search::docsummary {
+
+AttributeFieldWriter::AttributeFieldWriter(const vespalib::string &fieldName,
+ const IAttributeVector &attr)
+ : _fieldName(fieldName),
+ _attr(attr),
+ _size(0)
+{
+}
+
+AttributeFieldWriter::~AttributeFieldWriter() = default;
+
+namespace {
+
+template <class Content>
+class WriteField : public AttributeFieldWriter
+{
+protected:
+ Content _content;
+
+ WriteField(const vespalib::string &fieldName, const IAttributeVector &attr);
+ ~WriteField() override;
+private:
+ void fetch(uint32_t docId) override;
+};
+
+class WriteStringField : public WriteField<search::attribute::ConstCharContent>
+{
+public:
+ WriteStringField(const vespalib::string &fieldName,
+ const IAttributeVector &attr);
+ ~WriteStringField() override;
+ void print(uint32_t idx, Cursor &cursor) override;
+};
+
+
+class WriteFloatField : public WriteField<search::attribute::FloatContent>
+{
+public:
+ WriteFloatField(const vespalib::string &fieldName,
+ const IAttributeVector &attr);
+ ~WriteFloatField() override;
+ void print(uint32_t idx, Cursor &cursor) override;
+};
+
+class WriteIntField : public WriteField<search::attribute::IntegerContent>
+{
+ IAttributeVector::largeint_t _undefined;
+public:
+ WriteIntField(const vespalib::string &fieldName,
+ const IAttributeVector &attr,
+ IAttributeVector::largeint_t undefined);
+ ~WriteIntField() override;
+ void print(uint32_t idx, Cursor &cursor) override;
+};
+
+template <class Content>
+WriteField<Content>::WriteField(const vespalib::string &fieldName, const IAttributeVector &attr)
+ : AttributeFieldWriter(fieldName, attr),
+ _content()
+{
+}
+
+template <class Content>
+WriteField<Content>::~WriteField() = default;
+
+template <class Content>
+void
+WriteField<Content>::fetch(uint32_t docId)
+{
+ _content.fill(_attr, docId);
+ _size = _content.size();
+}
+
+WriteStringField::WriteStringField(const vespalib::string &fieldName,
+ const IAttributeVector &attr)
+ : WriteField(fieldName, attr)
+{
+}
+
+WriteStringField::~WriteStringField() = default;
+
+void
+WriteStringField::print(uint32_t idx, Cursor &cursor)
+{
+ if (idx < _size) {
+ const char *s = _content[idx];
+ if (s[0] != '\0') {
+ cursor.setString(_fieldName, vespalib::Memory(s));
+ }
+ }
+}
+
+WriteFloatField::WriteFloatField(const vespalib::string &fieldName,
+ const IAttributeVector &attr)
+ : WriteField(fieldName, attr)
+{
+}
+
+WriteFloatField::~WriteFloatField() = default;
+
+void
+WriteFloatField::print(uint32_t idx, Cursor &cursor)
+{
+ if (idx < _size) {
+ double val = _content[idx];
+ if (!search::attribute::isUndefined(val)) {
+ cursor.setDouble(_fieldName, val);
+ }
+ }
+}
+
+WriteIntField::WriteIntField(const vespalib::string &fieldName,
+ const IAttributeVector &attr,
+ IAttributeVector::largeint_t undefined)
+ : WriteField(fieldName, attr),
+ _undefined(undefined)
+{
+}
+
+WriteIntField::~WriteIntField() = default;
+
+void
+WriteIntField::print(uint32_t idx, Cursor &cursor)
+{
+ if (idx < _size) {
+ auto val = _content[idx];
+ if (val != _undefined) {
+ cursor.setLong(_fieldName, _content[idx]);
+ }
+ }
+}
+
+}
+
+std::unique_ptr<AttributeFieldWriter>
+AttributeFieldWriter::create(const vespalib::string &fieldName, const IAttributeVector &attr)
+{
+ switch (attr.getBasicType()) {
+ case BasicType::INT8:
+ return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int8_t>());
+ case BasicType::INT16:
+ return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int16_t>());
+ case BasicType::INT32:
+ return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int32_t>());
+ case BasicType::INT64:
+ return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int64_t>());
+ case BasicType::FLOAT:
+ case BasicType::DOUBLE:
+ return std::make_unique<WriteFloatField>(fieldName, attr);
+ case BasicType::STRING:
+ return std::make_unique<WriteStringField>(fieldName, attr);
+ default:
+ assert(false);
+ abort();
+ }
+}
+
+}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h
new file mode 100644
index 00000000000..104455a0e79
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h
@@ -0,0 +1,34 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/vespalib/data/memory.h>
+
+namespace search::attribute { class IAttributeVector; }
+namespace vespalib::slime { class Cursor; }
+
+namespace search::docsummary {
+
+/*
+ * This class reads values from a struct field attribute and inserts
+ * them into proper position in an array of struct or map of struct.
+ * If the value to be inserted is considered to be undefined then
+ * the value is not inserted.
+ */
+class AttributeFieldWriter
+{
+protected:
+ const vespalib::Memory _fieldName;
+ const search::attribute::IAttributeVector &_attr;
+ size_t _size;
+public:
+ AttributeFieldWriter(const vespalib::string &fieldName,
+ const search::attribute::IAttributeVector &attr);
+ virtual ~AttributeFieldWriter();
+ virtual void fetch(uint32_t docId) = 0;
+ virtual void print(uint32_t idx, vespalib::slime::Cursor &cursor) = 0;
+ static std::unique_ptr<AttributeFieldWriter> create(const vespalib::string &fieldName, const search::attribute::IAttributeVector &attr);
+ uint32_t size() const { return _size; }
+};
+
+}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h
new file mode 100644
index 00000000000..940cfd6ce06
--- /dev/null
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h
@@ -0,0 +1,21 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib::slime { class Inserter; }
+
+namespace search::docsummary {
+
+/*
+ * A subclass of this class can be instantiated by a document field writer to
+ * track extra state during handling of a document summary request and
+ * insert the field value using that state.
+ */
+class DocsumFieldWriterState
+{
+public:
+ virtual void insertField(uint32_t docId, vespalib::slime::Inserter &target) = 0;
+ virtual ~DocsumFieldWriterState() = default;
+};
+
+}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp
index 7b463352155..18e7e471663 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp
@@ -21,6 +21,12 @@ using search::common::Location;
const vespalib::string IDocsumFieldWriter::_empty("");
+bool
+IDocsumFieldWriter::setFieldWriterStateIndex(uint32_t)
+{
+ return false; // Don't need any field writer state by default
+}
+
//--------------------------------------------------------------------------
EmptyDFW::EmptyDFW() { }
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h
index abce5c12227..51079f7736e 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h
@@ -40,6 +40,7 @@ public:
}
void setIndex(size_t v) { _index = v; }
size_t getIndex() const { return _index; }
+ virtual bool setFieldWriterStateIndex(uint32_t fieldWriterStateIndex);
private:
size_t _index;
static const vespalib::string _empty;
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp
index 91953612f6a..b0431b6e6ac 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp
@@ -4,6 +4,7 @@
#include <vespa/juniper/rpinterface.h>
#include <vespa/searchcommon/attribute/iattributecontext.h>
#include <vespa/searchlib/common/location.h>
+#include "docsum_field_writer_state.h"
namespace search {
namespace docsummary {
@@ -19,6 +20,7 @@ GetDocsumsState::GetDocsumsState(GetDocsumsStateCallback &callback)
_docSumFieldSpace(_docSumFieldSpaceStore, sizeof(_docSumFieldSpaceStore)), // only alloc buffer if needed
_attrCtx(),
_attributes(),
+ _fieldWriterStates(),
_jsonStringer(),
_parsedLocation(),
_summaryFeatures(NULL),
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h
index 4ffed79043e..fa47d5244eb 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h
@@ -23,6 +23,7 @@ namespace search::docsummary {
class GetDocsumsState;
class IDocsumEnvironment;
class KeywordExtractor;
+class DocsumFieldWriterState;
class GetDocsumsStateCallback
{
@@ -70,6 +71,7 @@ public:
char _docSumFieldSpaceStore[2048];
std::unique_ptr<search::attribute::IAttributeContext> _attrCtx;
std::vector<const search::attribute::IAttributeVector *> _attributes;
+ std::vector<std::unique_ptr<DocsumFieldWriterState>> _fieldWriterStates;
vespalib::JSONStringer _jsonStringer;
// used by AbsDistanceDFW
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp
index bf660b1319b..abd1780b773 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp
@@ -2,6 +2,7 @@
#include "docsumwriter.h"
#include "docsumstate.h"
+#include "docsum_field_writer_state.h"
#include <vespa/searchlib/common/transport.h>
#include <vespa/searchlib/util/slime_output_raw_buf_adapter.h>
#include <vespa/searchlib/attribute/iattributemanager.h>
@@ -77,7 +78,6 @@ DynamicDocsumWriter::resolveInputClass(ResolveClassInfo &rci, uint32_t id) const
}
}
-
static void convertEntry(GetDocsumsState *state,
const ResConfigEntry *resCfg,
const ResEntry *entry,
@@ -194,6 +194,7 @@ DynamicDocsumWriter::DynamicDocsumWriter( ResultConfig *config, KeywordExtractor
_defaultOutputClass(ResultConfig::NoClassID()),
_numClasses(config->GetNumResultClasses()),
_numEnumValues(config->GetFieldNameEnum().GetNumEntries()),
+ _numFieldWriterStates(0),
_classInfoTable(nullptr),
_overrideTable(nullptr)
{
@@ -267,6 +268,9 @@ DynamicDocsumWriter::Override(const char *fieldName, IDocsumFieldWriter *writer)
writer->setIndex(fieldEnumValue);
_overrideTable[fieldEnumValue] = writer;
+ if (writer->setFieldWriterStateIndex(_numFieldWriterStates)) {
+ ++_numFieldWriterStates;
+ }
for (ResultConfig::iterator it(_resultConfig->begin()), mt(_resultConfig->end()); it != mt; it++) {
@@ -288,6 +292,7 @@ DynamicDocsumWriter::InitState(IAttributeManager & attrMan, GetDocsumsState *sta
state->_kwExtractor = _keywordExtractor;
state->_attrCtx = attrMan.createContext();
state->_attributes.resize(_numEnumValues);
+ state->_fieldWriterStates.resize(_numFieldWriterStates);
for (size_t i(0); i < state->_attributes.size(); i++) {
const IDocsumFieldWriter *fw = _overrideTable[i];
if (fw) {
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h
index 6ef21a71e74..92b26d5cf14 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h
@@ -54,6 +54,7 @@ private:
uint32_t _defaultOutputClass;
uint32_t _numClasses;
uint32_t _numEnumValues;
+ uint32_t _numFieldWriterStates;
ResultClass::DynamicInfo *_classInfoTable;
IDocsumFieldWriter **_overrideTable;
diff --git a/service-monitor/pom.xml b/service-monitor/pom.xml
index 70f9d4aa655..b8065ed3636 100644
--- a/service-monitor/pom.xml
+++ b/service-monitor/pom.xml
@@ -64,6 +64,12 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>vespa-athenz</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>com.google.inject</groupId>
<artifactId>guice</artifactId>
<scope>provided</scope>
@@ -76,6 +82,23 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <version>4.5</version>
+ <!-- This is necessary to get 4.4's HostnameVerifier API of SSLConnectionSocketFactory::new -->
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java
index 35003313775..75e61eef772 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java
@@ -11,7 +11,13 @@ import com.yahoo.vespa.applicationmodel.ServiceType;
* @author hakon
*/
public interface ServiceStatusProvider {
- /** Get the {@link ServiceStatus} of a particular service. */
+ /**
+ * Get the {@link ServiceStatus} of a particular service.
+ *
+ * <p>{@link ServiceStatus#NOT_CHECKED NOT_CHECKED} must be returned if the
+ * service status provider does does not monitor the service status for
+ * the particular application, cluster, service type, and config id.
+ */
ServiceStatus getStatus(ApplicationId applicationId,
ClusterId clusterId,
ServiceType serviceType,
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java
index ec2702bcfaf..cbdcce125cc 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java
@@ -1,13 +1,148 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.application;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.HostInfo;
+import com.yahoo.config.model.api.ServiceInfo;
+import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.applicationmodel.ApplicationInstance;
+import com.yahoo.vespa.applicationmodel.ApplicationInstanceId;
+import com.yahoo.vespa.applicationmodel.ClusterId;
+import com.yahoo.vespa.applicationmodel.ConfigId;
+import com.yahoo.vespa.applicationmodel.HostName;
+import com.yahoo.vespa.applicationmodel.ServiceCluster;
+import com.yahoo.vespa.applicationmodel.ServiceClusterKey;
+import com.yahoo.vespa.applicationmodel.ServiceInstance;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.vespa.applicationmodel.ServiceType;
+import com.yahoo.vespa.applicationmodel.TenantId;
import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
+import com.yahoo.vespa.service.monitor.internal.ServiceId;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION;
/**
+ * Class to generate an ApplicationInstance given service status for a standard (deployed) application.
+ *
* @author hakon
*/
-public interface ApplicationInstanceGenerator {
- /** Make an ApplicationInstance based on current service status. */
- ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider);
+public class ApplicationInstanceGenerator {
+ public static final String CLUSTER_ID_PROPERTY_NAME = "clustername";
+
+ private final ApplicationInfo applicationInfo;
+ private final Zone zone;
+
+ public ApplicationInstanceGenerator(ApplicationInfo applicationInfo, Zone zone) {
+ this.applicationInfo = applicationInfo;
+ this.zone = zone;
+ }
+
+ public ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider) {
+ Map<ServiceClusterKey, Set<ServiceInstance>> groupedServiceInstances = new HashMap<>();
+
+ for (HostInfo host : applicationInfo.getModel().getHosts()) {
+ HostName hostName = new HostName(host.getHostname());
+ for (ServiceInfo serviceInfo : host.getServices()) {
+ ServiceClusterKey serviceClusterKey = toServiceClusterKey(serviceInfo);
+ ServiceInstance serviceInstance =
+ toServiceInstance(
+ applicationInfo.getApplicationId(),
+ serviceClusterKey.clusterId(),
+ serviceInfo,
+ hostName,
+ serviceStatusProvider);
+
+ if (!groupedServiceInstances.containsKey(serviceClusterKey)) {
+ groupedServiceInstances.put(serviceClusterKey, new HashSet<>());
+ }
+ groupedServiceInstances.get(serviceClusterKey).add(serviceInstance);
+ }
+ }
+
+ Set<ServiceCluster> serviceClusters = groupedServiceInstances.entrySet().stream()
+ .map(entry -> new ServiceCluster(
+ entry.getKey().clusterId(),
+ entry.getKey().serviceType(),
+ entry.getValue()))
+ .collect(Collectors.toSet());
+
+ ApplicationInstance applicationInstance = new ApplicationInstance(
+ new TenantId(applicationInfo.getApplicationId().tenant().toString()),
+ toApplicationInstanceId(applicationInfo, zone),
+ serviceClusters);
+
+ // Fill back-references
+ for (ServiceCluster serviceCluster : applicationInstance.serviceClusters()) {
+ serviceCluster.setApplicationInstance(applicationInstance);
+ for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) {
+ serviceInstance.setServiceCluster(serviceCluster);
+ }
+ }
+
+ return applicationInstance;
+ }
+
+ private ServiceInstance toServiceInstance(
+ ApplicationId applicationId,
+ ClusterId clusterId,
+ ServiceInfo serviceInfo,
+ HostName hostName,
+ ServiceStatusProvider serviceStatusProvider) {
+ ConfigId configId = toConfigId(serviceInfo);
+
+ ServiceStatus status = serviceStatusProvider.getStatus(
+ applicationId,
+ clusterId,
+ toServiceType(serviceInfo), configId);
+
+ return new ServiceInstance(configId, hostName, status);
+ }
+
+ private ApplicationInstanceId toApplicationInstanceId(ApplicationInfo applicationInfo, Zone zone) {
+ if (applicationInfo.getApplicationId().equals(CONFIG_SERVER_APPLICATION.getApplicationId())) {
+ // Removing this historical discrepancy would break orchestration during rollout.
+ // An alternative may be to use a feature flag and flip it between releases,
+ // once that's available.
+ return new ApplicationInstanceId(applicationInfo.getApplicationId().application().value());
+ } else {
+ return new ApplicationInstanceId(String.format("%s:%s:%s:%s",
+ applicationInfo.getApplicationId().application().value(),
+ zone.environment().value(),
+ zone.region().value(),
+ applicationInfo.getApplicationId().instance().value()));
+ }
+ }
+
+ public static ServiceId getServiceId(ApplicationInfo applicationInfo, ServiceInfo serviceInfo) {
+ return new ServiceId(
+ applicationInfo.getApplicationId(),
+ getClusterId(serviceInfo),
+ toServiceType(serviceInfo),
+ toConfigId(serviceInfo));
+ }
+
+ private static ClusterId getClusterId(ServiceInfo serviceInfo) {
+ return new ClusterId(serviceInfo.getProperty(CLUSTER_ID_PROPERTY_NAME).orElse(""));
+ }
+
+ private static ServiceClusterKey toServiceClusterKey(ServiceInfo serviceInfo) {
+ ClusterId clusterId = getClusterId(serviceInfo);
+ ServiceType serviceType = toServiceType(serviceInfo);
+ return new ServiceClusterKey(clusterId, serviceType);
+ }
+
+ private static ServiceType toServiceType(ServiceInfo serviceInfo) {
+ return new ServiceType(serviceInfo.getServiceType());
+ }
+
+ private static ConfigId toConfigId(ServiceInfo serviceInfo) {
+ return new ConfigId(serviceInfo.getConfigId());
+ }
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java
deleted file mode 100644
index 76ca59cf583..00000000000
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.service.monitor.application;
-
-import com.yahoo.vespa.applicationmodel.ApplicationInstance;
-import com.yahoo.vespa.applicationmodel.ConfigId;
-import com.yahoo.vespa.applicationmodel.HostName;
-import com.yahoo.vespa.applicationmodel.ServiceCluster;
-import com.yahoo.vespa.applicationmodel.ServiceInstance;
-import com.yahoo.vespa.applicationmodel.ServiceStatus;
-import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
-
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * Class for generating an ApplicationInstance for the synthesized config server application.
- *
- * @author hakon
- */
-public class ConfigServerAppGenerator implements ApplicationInstanceGenerator {
- private final List<String> hostnames;
-
- public ConfigServerAppGenerator(List<String> hostnames) {
- this.hostnames = hostnames;
- }
-
- @Override
- public ApplicationInstance makeApplicationInstance(ServiceStatusProvider statusProvider) {
- Set<ServiceInstance> serviceInstances = hostnames.stream()
- .map(hostname -> makeServiceInstance(hostname, statusProvider))
- .collect(Collectors.toSet());
-
- ServiceCluster serviceCluster = new ServiceCluster(
- ConfigServerApplication.CLUSTER_ID,
- ConfigServerApplication.SERVICE_TYPE,
- serviceInstances);
-
- Set<ServiceCluster> serviceClusters = new HashSet<>();
- serviceClusters.add(serviceCluster);
-
- ApplicationInstance applicationInstance = new ApplicationInstance(
- ConfigServerApplication.TENANT_ID,
- ConfigServerApplication.APPLICATION_INSTANCE_ID,
- serviceClusters);
-
- // Fill back-references
- serviceCluster.setApplicationInstance(applicationInstance);
- for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) {
- serviceInstance.setServiceCluster(serviceCluster);
- }
-
- return applicationInstance;
- }
-
- private ServiceInstance makeServiceInstance(String hostname, ServiceStatusProvider statusProvider) {
- ConfigId configId = new ConfigId(ConfigServerApplication.CONFIG_ID_PREFIX + hostname);
- ServiceStatus status = statusProvider.getStatus(
- ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(),
- ConfigServerApplication.CLUSTER_ID,
- ConfigServerApplication.SERVICE_TYPE,
- configId);
-
- return new ServiceInstance(configId, new HostName(hostname), status);
- }
-}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java
index 132bb0927b8..5ad38cebcfc 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java
@@ -1,12 +1,26 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.application;
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.HostInfo;
+import com.yahoo.config.model.api.PortInfo;
+import com.yahoo.config.model.api.ServiceInfo;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.NodeType;
import com.yahoo.vespa.applicationmodel.ApplicationInstanceId;
import com.yahoo.vespa.applicationmodel.ClusterId;
+import com.yahoo.vespa.applicationmodel.ConfigId;
import com.yahoo.vespa.applicationmodel.ServiceType;
import com.yahoo.vespa.applicationmodel.TenantId;
+import com.yahoo.vespa.service.monitor.internal.ModelGenerator;
+import com.yahoo.vespa.service.monitor.internal.health.ApplicationHealthMonitor;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
/**
* A service/application model of the config server with health status.
@@ -21,8 +35,44 @@ public class ConfigServerApplication extends HostedVespaApplication {
public static final ServiceType SERVICE_TYPE = new ServiceType("configserver");
public static final String CONFIG_ID_PREFIX = "configid.";
+ public static ConfigId configIdFrom(int index) {
+ return new ConfigId(CONFIG_ID_PREFIX + index);
+ }
+
private ConfigServerApplication() {
super("zone-config-servers", NodeType.config,
ClusterSpec.Type.admin, ClusterSpec.Id.from("zone-config-servers"));
}
+
+ public ApplicationInfo makeApplicationInfo(ConfigserverConfig config) {
+ List<HostInfo> hostInfos = new ArrayList<>();
+ List<ConfigserverConfig.Zookeeperserver> zooKeeperServers = config.zookeeperserver();
+ for (int index = 0; index < zooKeeperServers.size(); ++index) {
+ String hostname = zooKeeperServers.get(index).hostname();
+ hostInfos.add(makeHostInfo(hostname, config.httpport(), index));
+ }
+
+ return new ApplicationInfo(
+ CONFIG_SERVER_APPLICATION.getApplicationId(),
+ 0,
+ new HostsModel(hostInfos));
+ }
+
+ private static HostInfo makeHostInfo(String hostname, int port, int configIndex) {
+ PortInfo portInfo = new PortInfo(port, ApplicationHealthMonitor.PORT_TAGS_HEALTH);
+
+ Map<String, String> properties = new HashMap<>();
+ properties.put(ModelGenerator.CLUSTER_ID_PROPERTY_NAME, CLUSTER_ID.s());
+
+ ServiceInfo serviceInfo = new ServiceInfo(
+ // service name == service type for the first service of each type on each host
+ SERVICE_TYPE.s(),
+ SERVICE_TYPE.s(),
+ Collections.singletonList(portInfo),
+ properties,
+ configIdFrom(configIndex).s(),
+ hostname);
+
+ return new HostInfo(hostname, Collections.singletonList(serviceInfo));
+ }
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java
deleted file mode 100644
index 2691a8bf1ee..00000000000
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.service.monitor.application;
-
-import com.yahoo.config.model.api.ApplicationInfo;
-import com.yahoo.config.model.api.HostInfo;
-import com.yahoo.config.model.api.ServiceInfo;
-import com.yahoo.config.provision.ApplicationId;
-import com.yahoo.config.provision.Zone;
-import com.yahoo.vespa.applicationmodel.ApplicationInstance;
-import com.yahoo.vespa.applicationmodel.ApplicationInstanceId;
-import com.yahoo.vespa.applicationmodel.ClusterId;
-import com.yahoo.vespa.applicationmodel.ConfigId;
-import com.yahoo.vespa.applicationmodel.HostName;
-import com.yahoo.vespa.applicationmodel.ServiceCluster;
-import com.yahoo.vespa.applicationmodel.ServiceClusterKey;
-import com.yahoo.vespa.applicationmodel.ServiceInstance;
-import com.yahoo.vespa.applicationmodel.ServiceStatus;
-import com.yahoo.vespa.applicationmodel.ServiceType;
-import com.yahoo.vespa.applicationmodel.TenantId;
-import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * Class to generate an ApplicationInstance given service status for a standard (deployed) application.
- *
- * @author hakon
- */
-public class DeployedAppGenerator implements ApplicationInstanceGenerator {
- public static final String CLUSTER_ID_PROPERTY_NAME = "clustername";
-
- private final ApplicationInfo applicationInfo;
- private final Zone zone;
-
- public DeployedAppGenerator(ApplicationInfo applicationInfo, Zone zone) {
- this.applicationInfo = applicationInfo;
- this.zone = zone;
- }
-
- @Override
- public ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider) {
- Map<ServiceClusterKey, Set<ServiceInstance>> groupedServiceInstances = new HashMap<>();
-
- for (HostInfo host : applicationInfo.getModel().getHosts()) {
- HostName hostName = new HostName(host.getHostname());
- for (ServiceInfo serviceInfo : host.getServices()) {
- ServiceClusterKey serviceClusterKey = toServiceClusterKey(serviceInfo);
- ServiceInstance serviceInstance =
- toServiceInstance(
- applicationInfo.getApplicationId(),
- serviceClusterKey.clusterId(),
- serviceInfo,
- hostName,
- serviceStatusProvider);
-
- if (!groupedServiceInstances.containsKey(serviceClusterKey)) {
- groupedServiceInstances.put(serviceClusterKey, new HashSet<>());
- }
- groupedServiceInstances.get(serviceClusterKey).add(serviceInstance);
- }
- }
-
- Set<ServiceCluster> serviceClusters = groupedServiceInstances.entrySet().stream()
- .map(entry -> new ServiceCluster(
- entry.getKey().clusterId(),
- entry.getKey().serviceType(),
- entry.getValue()))
- .collect(Collectors.toSet());
-
- ApplicationInstance applicationInstance = new ApplicationInstance(
- new TenantId(applicationInfo.getApplicationId().tenant().toString()),
- toApplicationInstanceId(applicationInfo, zone),
- serviceClusters);
-
- // Fill back-references
- for (ServiceCluster serviceCluster : applicationInstance.serviceClusters()) {
- serviceCluster.setApplicationInstance(applicationInstance);
- for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) {
- serviceInstance.setServiceCluster(serviceCluster);
- }
- }
-
- return applicationInstance;
- }
-
- static ClusterId getClusterId(ServiceInfo serviceInfo) {
- return new ClusterId(serviceInfo.getProperty(CLUSTER_ID_PROPERTY_NAME).orElse(""));
- }
-
- private ServiceClusterKey toServiceClusterKey(ServiceInfo serviceInfo) {
- ClusterId clusterId = getClusterId(serviceInfo);
- ServiceType serviceType = toServiceType(serviceInfo);
- return new ServiceClusterKey(clusterId, serviceType);
- }
-
- private ServiceInstance toServiceInstance(
- ApplicationId applicationId,
- ClusterId clusterId,
- ServiceInfo serviceInfo,
- HostName hostName,
- ServiceStatusProvider serviceStatusProvider) {
- ConfigId configId = new ConfigId(serviceInfo.getConfigId());
-
- ServiceStatus status = serviceStatusProvider.getStatus(
- applicationId,
- clusterId,
- toServiceType(serviceInfo), configId);
-
- return new ServiceInstance(configId, hostName, status);
- }
-
- private ApplicationInstanceId toApplicationInstanceId(ApplicationInfo applicationInfo, Zone zone) {
- return new ApplicationInstanceId(String.format("%s:%s:%s:%s",
- applicationInfo.getApplicationId().application().value(),
- zone.environment().value(),
- zone.region().value(),
- applicationInfo.getApplicationId().instance().value()));
- }
-
- private ServiceType toServiceType(ServiceInfo serviceInfo) {
- return new ServiceType(serviceInfo.getServiceType());
- }
-}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java
new file mode 100644
index 00000000000..225ffb0adbc
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java
@@ -0,0 +1,75 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.application;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.config.model.api.FileDistribution;
+import com.yahoo.config.model.api.HostInfo;
+import com.yahoo.config.model.api.Model;
+import com.yahoo.config.provision.AllocatedHosts;
+import com.yahoo.vespa.config.ConfigKey;
+import com.yahoo.vespa.config.ConfigPayload;
+import com.yahoo.vespa.config.buildergen.ConfigDefinition;
+
+import java.time.Instant;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Model that only supports the subset necessary to create an ApplicationInstance.
+ *
+ * @author hakon
+ */
+public class HostsModel implements Model {
+ private final Collection<HostInfo> hosts;
+
+ public HostsModel(List<HostInfo> hosts) {
+ this.hosts = Collections.unmodifiableCollection(hosts);
+ }
+
+ @Override
+ public Collection<HostInfo> getHosts() {
+ return hosts;
+ }
+
+ @Override
+ public ConfigPayload getConfig(ConfigKey<?> configKey, ConfigDefinition configDefinition) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Set<ConfigKey<?>> allConfigsProduced() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Set<String> allConfigIds() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void distributeFiles(FileDistribution fileDistribution) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Set<FileReference> fileReferences() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public AllocatedHosts allocatedHosts() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean allowModelVersionMismatch(Instant now) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean skipOldConfigModels(Instant now) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java
index 6bbf0cb6d1d..c10015d3bfa 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java
@@ -21,8 +21,8 @@ public class ZoneApplication {
.createHostedVespaApplicationId("routing");
public static boolean isNodeAdminService(ApplicationId applicationId,
- ClusterId clusterId,
- ServiceType serviceType) {
+ ClusterId clusterId,
+ ServiceType serviceType) {
return Objects.equals(applicationId, ZONE_APPLICATION_ID) &&
Objects.equals(serviceType, ServiceType.CONTAINER) &&
Objects.equals(clusterId, ClusterId.NODE_ADMIN);
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java
new file mode 100644
index 00000000000..80e0bfd2710
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java
@@ -0,0 +1,42 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.SuperModel;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION;
+
+/**
+ * The {@code DuperModel} unites the {@link com.yahoo.config.model.api.SuperModel SuperModel}
+ * with the synthetically produced applications like the config server application.
+ *
+ * @author hakon
+ */
+public class DuperModel {
+ private final List<ApplicationInfo> staticApplicationInfos = new ArrayList<>();
+
+ public DuperModel(ConfigserverConfig configServerConfig) {
+ // Single-tenant applications have the config server as part of the application model.
+ // TODO: Add health monitoring for config server when part of application model.
+ if (configServerConfig.multitenant()) {
+ staticApplicationInfos.add(CONFIG_SERVER_APPLICATION.makeApplicationInfo(configServerConfig));
+ }
+ }
+
+ /** For testing. */
+ DuperModel(ApplicationInfo... staticApplicationInfos) {
+ this.staticApplicationInfos.addAll(Arrays.asList(staticApplicationInfos));
+ }
+
+ public List<ApplicationInfo> getApplicationInfos(SuperModel superModelSnapshot) {
+ List<ApplicationInfo> allApplicationInfos = new ArrayList<>();
+ allApplicationInfos.addAll(staticApplicationInfos);
+ allApplicationInfos.addAll(superModelSnapshot.getAllApplicationInfos());
+ return allApplicationInfos;
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java
new file mode 100644
index 00000000000..235c7db5c36
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java
@@ -0,0 +1,28 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
+
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.SuperModel;
+import com.yahoo.config.provision.ApplicationId;
+
+/**
+ * Interface for listening for changes to the {@link DuperModel}.
+ *
+ * @author hakon
+ */
+public interface DuperModelListener {
+ /**
+ * An application has been activated:
+ *
+ * <ul>
+ * <li>A synthetic application like the config server application has been added/"activated"
+ * <li>A super model application has been activated (see
+ * {@link com.yahoo.config.model.api.SuperModelListener#applicationActivated(SuperModel, ApplicationInfo)
+ * SuperModelListener}
+ * </ul>
+ */
+ void applicationActivated(ApplicationInfo application);
+
+ /** Application has been removed. */
+ void applicationRemoved(ApplicationId id);
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java
index 9da449289a7..ad2f223acf8 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java
@@ -1,56 +1,40 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.internal;
-import com.yahoo.config.model.api.SuperModel;
+import com.yahoo.config.model.api.ApplicationInfo;
import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.applicationmodel.ApplicationInstance;
import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference;
import com.yahoo.vespa.service.monitor.ServiceModel;
import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
import com.yahoo.vespa.service.monitor.application.ApplicationInstanceGenerator;
-import com.yahoo.vespa.service.monitor.application.ConfigServerAppGenerator;
-import com.yahoo.vespa.service.monitor.application.DeployedAppGenerator;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
- * Util to convert SuperModel to ServiceModel and application model classes
+ * Util to make ServiceModel and its related application model classes
*/
public class ModelGenerator {
public static final String CLUSTER_ID_PROPERTY_NAME = "clustername";
- private final List<ApplicationInstanceGenerator> staticGenerators;
-
- public ModelGenerator(List<String> configServerHosts) {
- if (configServerHosts.isEmpty()) {
- staticGenerators = Collections.emptyList();
- } else {
- staticGenerators = Collections.singletonList(new ConfigServerAppGenerator(configServerHosts));
- }
- }
-
/**
* Create service model based primarily on super model.
*
* If the configServerhosts is non-empty, a config server application is added.
*/
- ServiceModel toServiceModel(
- SuperModel superModel,
- Zone zone,
- ServiceStatusProvider serviceStatusProvider) {
- List<ApplicationInstanceGenerator> generators = new ArrayList<>(staticGenerators);
- superModel.getAllApplicationInfos()
- .forEach(info -> generators.add(new DeployedAppGenerator(info, zone)));
-
- Map<ApplicationInstanceReference, ApplicationInstance> applicationInstances = generators.stream()
- .map(generator -> generator.makeApplicationInstance(serviceStatusProvider))
- .collect(Collectors.toMap(ApplicationInstance::reference, Function.identity()));
+ public ServiceModel toServiceModel(List<ApplicationInfo> allApplicationInfos,
+ Zone zone,
+ ServiceStatusProvider serviceStatusProvider) {
+ Map<ApplicationInstanceReference, ApplicationInstance> applicationInstances =
+ allApplicationInfos.stream()
+ .map(info -> new ApplicationInstanceGenerator(info, zone)
+ .makeApplicationInstance(serviceStatusProvider))
+ .collect(Collectors.toMap(ApplicationInstance::reference, Function.identity()));
return new ServiceModel(applicationInstances);
}
+
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java
index 49863672c43..1edf3a18215 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java
@@ -1,11 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.service.monitor.internal;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
-import com.yahoo.config.model.api.SuperModelListener;
import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
/**
* @author hakon
*/
-public interface MonitorManager extends SuperModelListener, ServiceStatusProvider {
+public interface MonitorManager extends DuperModelListener, ServiceStatusProvider {
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java
new file mode 100644
index 00000000000..993ea7fed5c
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java
@@ -0,0 +1,75 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
+
+import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.vespa.applicationmodel.ClusterId;
+import com.yahoo.vespa.applicationmodel.ConfigId;
+import com.yahoo.vespa.applicationmodel.ServiceType;
+
+import javax.annotation.concurrent.Immutable;
+import java.util.Objects;
+
+/**
+ * Identifies a service.
+ *
+ * @author hakon
+ */
+@Immutable
+public class ServiceId {
+ private final ApplicationId applicationId;
+ private final ClusterId clusterId;
+ private final ServiceType serviceType;
+ private final ConfigId configId;
+
+ public ServiceId(ApplicationId applicationId,
+ ClusterId clusterId,
+ ServiceType serviceType,
+ ConfigId configId) {
+ this.applicationId = applicationId;
+ this.clusterId = clusterId;
+ this.serviceType = serviceType;
+ this.configId = configId;
+ }
+
+ public ApplicationId getApplicationId() {
+ return applicationId;
+ }
+
+ public ClusterId getClusterId() {
+ return clusterId;
+ }
+
+ public ServiceType getServiceType() {
+ return serviceType;
+ }
+
+ public ConfigId getConfigId() {
+ return configId;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ ServiceId serviceId = (ServiceId) o;
+ return Objects.equals(applicationId, serviceId.applicationId) &&
+ Objects.equals(clusterId, serviceId.clusterId) &&
+ Objects.equals(serviceType, serviceId.serviceType) &&
+ Objects.equals(configId, serviceId.configId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(applicationId, clusterId, serviceType, configId);
+ }
+
+ @Override
+ public String toString() {
+ return "ServiceId{" +
+ "applicationId=" + applicationId +
+ ", clusterId=" + clusterId +
+ ", serviceType=" + serviceType +
+ ", configId=" + configId +
+ '}';
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java
index 97c4fdda0f3..bd8fd4a50e0 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java
@@ -14,10 +14,7 @@ import com.yahoo.vespa.service.monitor.ServiceMonitor;
import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager;
import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
public class ServiceMonitorImpl implements ServiceMonitor {
private final ServiceModelCache serviceModelCache;
@@ -32,30 +29,20 @@ public class ServiceMonitorImpl implements ServiceMonitor {
Zone zone = superModelProvider.getZone();
ServiceMonitorMetrics metrics = new ServiceMonitorMetrics(metric, timer);
- UnionMonitorManager monitorManager = new UnionMonitorManager(
- slobrokMonitorManager,
- healthMonitorManager,
- configserverConfig);
+ DuperModel duperModel = new DuperModel(configserverConfig);
+ UnionMonitorManager monitorManager =
+ new UnionMonitorManager(slobrokMonitorManager, healthMonitorManager);
SuperModelListenerImpl superModelListener = new SuperModelListenerImpl(
monitorManager,
metrics,
- new ModelGenerator(toConfigServerList(configserverConfig)),
+ duperModel,
+ new ModelGenerator(),
zone);
superModelListener.start(superModelProvider);
serviceModelCache = new ServiceModelCache(superModelListener, timer);
}
- private List<String> toConfigServerList(ConfigserverConfig configserverConfig) {
- if (configserverConfig.multitenant()) {
- return configserverConfig.zookeeperserver().stream()
- .map(ConfigserverConfig.Zookeeperserver::hostname)
- .collect(Collectors.toList());
- }
-
- return Collections.emptyList();
- }
-
@Override
public Map<ApplicationInstanceReference, ApplicationInstance> getAllApplicationInstances() {
return serviceModelCache.get().getAllApplicationInstances();
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java
index b2f3617131b..f509809c33d 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java
@@ -8,7 +8,9 @@ import com.yahoo.config.model.api.SuperModelProvider;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.service.monitor.ServiceModel;
+import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
+import java.util.List;
import java.util.function.Supplier;
import java.util.logging.Logger;
@@ -16,6 +18,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
private static final Logger logger = Logger.getLogger(SuperModelListenerImpl.class.getName());
private final ServiceMonitorMetrics metrics;
+ private final DuperModel duperModel;
private final ModelGenerator modelGenerator;
private final Zone zone;
@@ -27,10 +30,12 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
SuperModelListenerImpl(MonitorManager monitorManager,
ServiceMonitorMetrics metrics,
+ DuperModel duperModel,
ModelGenerator modelGenerator,
Zone zone) {
this.monitorManager = monitorManager;
this.metrics = metrics;
+ this.duperModel = duperModel;
this.modelGenerator = modelGenerator;
this.zone = zone;
}
@@ -41,8 +46,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
// since applicationActivated()/applicationRemoved() may be called
// asynchronously even before snapshot() returns.
this.superModel = superModelProvider.snapshot(this);
- superModel.getAllApplicationInfos().stream().forEach(application ->
- monitorManager.applicationActivated(superModel, application));
+ duperModel.getApplicationInfos(superModel).forEach(monitorManager::applicationActivated);
}
}
@@ -50,7 +54,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
public void applicationActivated(SuperModel superModel, ApplicationInfo application) {
synchronized (monitor) {
this.superModel = superModel;
- monitorManager.applicationActivated(superModel, application);
+ monitorManager.applicationActivated(application);
}
}
@@ -58,7 +62,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
public void applicationRemoved(SuperModel superModel, ApplicationId id) {
synchronized (monitor) {
this.superModel = superModel;
- monitorManager.applicationRemoved(superModel, id);
+ monitorManager.applicationRemoved(id);
}
}
@@ -71,7 +75,9 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv
dummy(measurement);
// WARNING: The slobrok monitor manager may be out-of-sync with super model (no locking)
- return modelGenerator.toServiceModel(superModel, zone, monitorManager);
+ List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel);
+
+ return modelGenerator.toServiceModel(applicationInfos, zone, (ServiceStatusProvider) monitorManager);
}
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java
index 82d2043bd17..81cf6f2af5e 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java
@@ -1,16 +1,12 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.internal;
-import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.api.ApplicationInfo;
-import com.yahoo.config.model.api.SuperModel;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.vespa.applicationmodel.ClusterId;
import com.yahoo.vespa.applicationmodel.ConfigId;
import com.yahoo.vespa.applicationmodel.ServiceStatus;
import com.yahoo.vespa.applicationmodel.ServiceType;
-import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
-import com.yahoo.vespa.service.monitor.application.ZoneApplication;
import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager;
import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl;
@@ -20,14 +16,11 @@ import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImp
public class UnionMonitorManager implements MonitorManager {
private final SlobrokMonitorManagerImpl slobrokMonitorManager;
private final HealthMonitorManager healthMonitorManager;
- private final ConfigserverConfig configserverConfig;
UnionMonitorManager(SlobrokMonitorManagerImpl slobrokMonitorManager,
- HealthMonitorManager healthMonitorManager,
- ConfigserverConfig configserverConfig) {
+ HealthMonitorManager healthMonitorManager) {
this.slobrokMonitorManager = slobrokMonitorManager;
this.healthMonitorManager = healthMonitorManager;
- this.configserverConfig = configserverConfig;
}
@Override
@@ -35,33 +28,25 @@ public class UnionMonitorManager implements MonitorManager {
ClusterId clusterId,
ServiceType serviceType,
ConfigId configId) {
-
- if (applicationId.equals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId())) {
- // todo: use health
- return ServiceStatus.NOT_CHECKED;
+ // Trust the new health monitoring status if it actually monitors the particular service.
+ ServiceStatus status = healthMonitorManager.getStatus(applicationId, clusterId, serviceType, configId);
+ if (status != ServiceStatus.NOT_CHECKED) {
+ return status;
}
- MonitorManager monitorManager = useHealth(applicationId, clusterId, serviceType) ?
- healthMonitorManager :
- slobrokMonitorManager;
-
- return monitorManager.getStatus(applicationId, clusterId, serviceType, configId);
+ // fallback is the older slobrok
+ return slobrokMonitorManager.getStatus(applicationId, clusterId, serviceType, configId);
}
@Override
- public void applicationActivated(SuperModel superModel, ApplicationInfo application) {
- slobrokMonitorManager.applicationActivated(superModel, application);
- healthMonitorManager.applicationActivated(superModel, application);
+ public void applicationActivated(ApplicationInfo application) {
+ slobrokMonitorManager.applicationActivated(application);
+ healthMonitorManager.applicationActivated(application);
}
@Override
- public void applicationRemoved(SuperModel superModel, ApplicationId id) {
- slobrokMonitorManager.applicationRemoved(superModel, id);
- healthMonitorManager.applicationRemoved(superModel, id);
- }
-
- private boolean useHealth(ApplicationId applicationId, ClusterId clusterId, ServiceType serviceType) {
- return !configserverConfig.nodeAdminInContainer() &&
- ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType);
+ public void applicationRemoved(ApplicationId id) {
+ slobrokMonitorManager.applicationRemoved(id);
+ healthMonitorManager.applicationRemoved(id);
}
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java
new file mode 100644
index 00000000000..bd2658db8aa
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java
@@ -0,0 +1,102 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.HostInfo;
+import com.yahoo.config.model.api.PortInfo;
+import com.yahoo.config.model.api.ServiceInfo;
+import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.HostName;
+import com.yahoo.vespa.applicationmodel.ClusterId;
+import com.yahoo.vespa.applicationmodel.ConfigId;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.vespa.applicationmodel.ServiceType;
+import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
+import com.yahoo.vespa.service.monitor.application.ApplicationInstanceGenerator;
+import com.yahoo.vespa.service.monitor.internal.ServiceId;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Responsible for monitoring a whole application using /state/v1/health.
+ *
+ * @author hakon
+ */
+public class ApplicationHealthMonitor implements ServiceStatusProvider, AutoCloseable {
+ public static final String PORT_TAG_STATE = "STATE";
+ public static final String PORT_TAG_HTTP = "HTTP";
+ /** Port tags implying /state/v1/health is served */
+ public static final List<String> PORT_TAGS_HEALTH =
+ Collections.unmodifiableList(Arrays.asList(PORT_TAG_HTTP, PORT_TAG_STATE));
+
+ private final Map<ServiceId, HealthMonitor> healthMonitors;
+
+ public static ApplicationHealthMonitor startMonitoring(ApplicationInfo application) {
+ return new ApplicationHealthMonitor(makeHealthMonitors(application));
+ }
+
+ private ApplicationHealthMonitor(Map<ServiceId, HealthMonitor> healthMonitors) {
+ this.healthMonitors = healthMonitors;
+ }
+
+ @Override
+ public ServiceStatus getStatus(ApplicationId applicationId,
+ ClusterId clusterId,
+ ServiceType serviceType,
+ ConfigId configId) {
+ ServiceId serviceId = new ServiceId(applicationId, clusterId, serviceType, configId);
+ HealthMonitor monitor = healthMonitors.get(serviceId);
+ if (monitor == null) {
+ return ServiceStatus.NOT_CHECKED;
+ }
+
+ return monitor.getStatus();
+ }
+
+ @Override
+ public void close() {
+ healthMonitors.values().forEach(HealthMonitor::close);
+ healthMonitors.clear();
+ }
+
+ private static Map<ServiceId, HealthMonitor> makeHealthMonitors(ApplicationInfo application) {
+ Map<ServiceId, HealthMonitor> healthMonitors = new HashMap<>();
+ for (HostInfo hostInfo : application.getModel().getHosts()) {
+ for (ServiceInfo serviceInfo : hostInfo.getServices()) {
+ for (PortInfo portInfo : serviceInfo.getPorts()) {
+ maybeCreateHealthMonitor(
+ application,
+ hostInfo,
+ serviceInfo,
+ portInfo)
+ .ifPresent(healthMonitor -> healthMonitors.put(
+ ApplicationInstanceGenerator.getServiceId(application, serviceInfo),
+ healthMonitor));
+ }
+ }
+ }
+ return healthMonitors;
+ }
+
+ private static Optional<HealthMonitor> maybeCreateHealthMonitor(
+ ApplicationInfo applicationInfo,
+ HostInfo hostInfo,
+ ServiceInfo serviceInfo,
+ PortInfo portInfo) {
+ if (portInfo.getTags().containsAll(PORT_TAGS_HEALTH)) {
+ HostName hostname = HostName.from(hostInfo.getHostname());
+ HealthEndpoint endpoint = HealthEndpoint.forHttp(hostname, portInfo.getPort());
+ // todo: make HealthMonitor
+ // HealthMonitor healthMonitor = new HealthMonitor(endpoint);
+ // healthMonitor.startMonitoring();
+ return Optional.empty();
+ }
+
+ return Optional.empty();
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java
new file mode 100644
index 00000000000..43a02a385be
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java
@@ -0,0 +1,139 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.vespa.athenz.api.AthenzService;
+import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
+import org.apache.http.HttpEntity;
+import org.apache.http.HttpResponse;
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.config.Registry;
+import org.apache.http.config.RegistryBuilder;
+import org.apache.http.conn.ConnectionKeepAliveStrategy;
+import org.apache.http.conn.HttpClientConnectionManager;
+import org.apache.http.conn.socket.ConnectionSocketFactory;
+import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.DefaultConnectionKeepAliveStrategy;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.impl.conn.BasicHttpClientConnectionManager;
+import org.apache.http.protocol.HttpContext;
+import org.apache.http.util.EntityUtils;
+
+import javax.net.ssl.SSLContext;
+
+/**
+ * @author hakon
+ */
+public class HealthClient implements AutoCloseable, ServiceIdentityProvider.Listener {
+ private static final ObjectMapper mapper = new ObjectMapper();
+ private static final long MAX_CONTENT_LENGTH = 1L << 20; // 1 MB
+ private static final int DEFAULT_TIMEOUT_MILLIS = 1_000;
+
+ private static final ConnectionKeepAliveStrategy KEEP_ALIVE_STRATEGY =
+ new DefaultConnectionKeepAliveStrategy() {
+ @Override
+ public long getKeepAliveDuration(HttpResponse response, HttpContext context) {
+ long keepAlive = super.getKeepAliveDuration(response, context);
+ if (keepAlive == -1) {
+ // Keep connections alive 60 seconds if a keep-alive value
+ // has not be explicitly set by the server
+ keepAlive = 60000;
+ }
+ return keepAlive;
+ }
+ };
+
+ private final HealthEndpoint endpoint;
+
+ private volatile CloseableHttpClient httpClient;
+
+ public HealthClient(HealthEndpoint endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ public void start() {
+ endpoint.getServiceIdentityProvider().ifPresent(provider -> {
+ onCredentialsUpdate(provider.getIdentitySslContext(), null);
+ provider.addIdentityListener(this);
+ });
+ }
+
+ @Override
+ public void onCredentialsUpdate(SSLContext sslContext, AthenzService ignored) {
+ SSLConnectionSocketFactory socketFactory =
+ new SSLConnectionSocketFactory(sslContext, endpoint.getHostnameVerifier().orElse(null));
+
+ Registry<ConnectionSocketFactory> registry = RegistryBuilder.<ConnectionSocketFactory>create()
+ .register("https", socketFactory)
+ .build();
+
+ HttpClientConnectionManager connectionManager = new BasicHttpClientConnectionManager(registry);
+
+ RequestConfig requestConfig = RequestConfig.custom()
+ .setConnectTimeout(DEFAULT_TIMEOUT_MILLIS) // establishment of connection
+ .setConnectionRequestTimeout(DEFAULT_TIMEOUT_MILLIS) // connection from connection manager
+ .setSocketTimeout(DEFAULT_TIMEOUT_MILLIS) // waiting for data
+ .build();
+
+ this.httpClient = HttpClients.custom()
+ .setKeepAliveStrategy(KEEP_ALIVE_STRATEGY)
+ .setConnectionManager(connectionManager)
+ .disableAutomaticRetries()
+ .setDefaultRequestConfig(requestConfig)
+ .build();
+ }
+
+ public HealthInfo getHealthInfo() {
+ try {
+ return probeHealth();
+ } catch (Exception e) {
+ return HealthInfo.fromException(e);
+ }
+ }
+
+ @Override
+ public void close() {
+ endpoint.getServiceIdentityProvider().ifPresent(provider -> provider.removeIdentityListener(this));
+
+ try {
+ httpClient.close();
+ } catch (Exception e) {
+ // ignore
+ }
+ httpClient = null;
+ }
+
+ private HealthInfo probeHealth() throws Exception {
+ HttpGet httpget = new HttpGet(endpoint.getStateV1HealthUrl().toString());
+ CloseableHttpResponse httpResponse;
+
+ CloseableHttpClient httpClient = this.httpClient;
+ if (httpClient == null) {
+ throw new IllegalStateException("HTTP client has closed");
+ }
+
+ httpResponse = httpClient.execute(httpget);
+
+ int httpStatusCode = httpResponse.getStatusLine().getStatusCode();
+ if (httpStatusCode < 200 || httpStatusCode >= 300) {
+ return HealthInfo.fromBadHttpStatusCode(httpStatusCode);
+ }
+
+ HttpEntity bodyEntity = httpResponse.getEntity();
+ long contentLength = bodyEntity.getContentLength();
+ if (contentLength > MAX_CONTENT_LENGTH) {
+ throw new IllegalArgumentException("Content too long: " + contentLength + " bytes");
+ }
+ String body = EntityUtils.toString(bodyEntity);
+ HealthResponse healthResponse = mapper.readValue(body, HealthResponse.class);
+
+ if (healthResponse.status == null || healthResponse.status.code == null) {
+ return HealthInfo.fromHealthStatusCode(HealthResponse.Status.DEFAULT_STATUS);
+ } else {
+ return HealthInfo.fromHealthStatusCode(healthResponse.status.code);
+ }
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java
new file mode 100644
index 00000000000..e9d17a9ab70
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java
@@ -0,0 +1,57 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.config.provision.HostName;
+import com.yahoo.vespa.athenz.api.AthenzIdentity;
+import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
+import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier;
+
+import javax.net.ssl.HostnameVerifier;
+import java.net.URL;
+import java.util.Collections;
+import java.util.Optional;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ * @author hakon
+ */
+class HealthEndpoint {
+ private final URL url;
+ private final Optional<HostnameVerifier> hostnameVerifier;
+ private final Optional<ServiceIdentityProvider> serviceIdentityProvider;
+
+ static HealthEndpoint forHttp(HostName hostname, int port) {
+ URL url = uncheck(() -> new URL("http", hostname.value(), port, "/state/v1/health"));
+ return new HealthEndpoint(url, Optional.empty(), Optional.empty());
+ }
+
+ static HealthEndpoint forHttps(HostName hostname,
+ int port,
+ ServiceIdentityProvider serviceIdentityProvider,
+ AthenzIdentity remoteIdentity) {
+ URL url = uncheck(() -> new URL("https", hostname.value(), port, "/state/v1/health"));
+ HostnameVerifier peerVerifier = new AthenzIdentityVerifier(Collections.singleton(remoteIdentity));
+ return new HealthEndpoint(url, Optional.of(serviceIdentityProvider), Optional.of(peerVerifier));
+ }
+
+ private HealthEndpoint(URL url,
+ Optional<ServiceIdentityProvider> serviceIdentityProvider,
+ Optional<HostnameVerifier> hostnameVerifier) {
+ this.url = url;
+ this.serviceIdentityProvider = serviceIdentityProvider;
+ this.hostnameVerifier = hostnameVerifier;
+ }
+
+ public URL getStateV1HealthUrl() {
+ return url;
+ }
+
+ public Optional<ServiceIdentityProvider> getServiceIdentityProvider() {
+ return serviceIdentityProvider;
+ }
+
+ public Optional<HostnameVerifier> getHostnameVerifier() {
+ return hostnameVerifier;
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java
new file mode 100644
index 00000000000..a3fe3cb3106
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java
@@ -0,0 +1,75 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.yolean.Exceptions;
+
+import java.time.Instant;
+import java.util.Optional;
+import java.util.OptionalInt;
+
+/**
+ * The result of a health lookup.
+ *
+ * @author hakon
+ */
+public class HealthInfo {
+ public static final String UP_STATUS_CODE = "up";
+
+ private final Optional<Exception> exception;
+ private final OptionalInt httpStatusCode;
+ private final Optional<String> healthStatusCode;
+ private final Instant time;
+
+ static HealthInfo fromException(Exception exception) {
+ return new HealthInfo(Optional.of(exception), OptionalInt.empty(), Optional.empty());
+ }
+
+ static HealthInfo fromBadHttpStatusCode(int httpStatusCode) {
+ return new HealthInfo(Optional.empty(), OptionalInt.of(httpStatusCode), Optional.empty());
+ }
+
+ static HealthInfo fromHealthStatusCode(String healthStatusCode) {
+ return new HealthInfo(Optional.empty(), OptionalInt.empty(), Optional.of(healthStatusCode));
+ }
+
+ static HealthInfo empty() {
+ return new HealthInfo(Optional.empty(), OptionalInt.empty(), Optional.empty());
+ }
+
+ private HealthInfo(Optional<Exception> exception,
+ OptionalInt httpStatusCode,
+ Optional<String> healthStatusCode) {
+ this.exception = exception;
+ this.httpStatusCode = httpStatusCode;
+ this.healthStatusCode = healthStatusCode;
+ this.time = Instant.now();
+ }
+
+ public boolean isHealthy() {
+ return healthStatusCode.map(UP_STATUS_CODE::equals).orElse(false);
+ }
+
+ public ServiceStatus toSerivceStatus() {
+ return isHealthy() ? ServiceStatus.UP : ServiceStatus.DOWN;
+ }
+
+ public Instant time() {
+ return time;
+ }
+
+ @Override
+ public String toString() {
+ if (isHealthy()) {
+ return UP_STATUS_CODE;
+ } else if (healthStatusCode.isPresent()) {
+ return "Bad health status code '" + healthStatusCode.get() + "'";
+ } else if (exception.isPresent()) {
+ return Exceptions.toMessageString(exception.get());
+ } else if (httpStatusCode.isPresent()) {
+ return "Bad HTTP response status code " + httpStatusCode.getAsInt();
+ } else {
+ return "No health info available";
+ }
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java
new file mode 100644
index 00000000000..fd809b32918
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java
@@ -0,0 +1,73 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.log.LogLevel;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+
+import java.time.Duration;
+import java.util.Random;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.logging.Logger;
+
+/**
+ * Used to monitor the health of a single URL endpoint.
+ *
+ * @author hakon
+ */
+public class HealthMonitor implements AutoCloseable {
+ private static final Logger logger = Logger.getLogger(HealthMonitor.class.getName());
+ private static final Duration DELAY = Duration.ofSeconds(20);
+ // About 'static': Javadoc says "Instances of java.util.Random are threadsafe."
+ private static final Random random = new Random();
+
+ private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1);
+ private final HealthClient healthClient;
+
+ private volatile HealthInfo lastHealthInfo = HealthInfo.empty();
+
+ public HealthMonitor(HealthEndpoint stateV1HealthEndpoint) {
+ this.healthClient = new HealthClient(stateV1HealthEndpoint);
+ }
+
+ /** For testing. */
+ HealthMonitor(HealthClient healthClient) {
+ this.healthClient = healthClient;
+ }
+
+ public void startMonitoring() {
+ healthClient.start();
+ executor.scheduleWithFixedDelay(
+ this::updateSynchronously,
+ initialDelayInSeconds(DELAY.getSeconds()),
+ DELAY.getSeconds(),
+ TimeUnit.SECONDS);
+ }
+
+ public ServiceStatus getStatus() {
+ // todo: return lastHealthInfo.toServiceStatus();
+ return ServiceStatus.NOT_CHECKED;
+ }
+
+ @Override
+ public void close() {
+ executor.shutdown();
+
+ try {
+ executor.awaitTermination(2, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ logger.log(LogLevel.INFO, "Interrupted while waiting for health monitor termination: " +
+ e.getMessage());
+ }
+
+ healthClient.close();
+ }
+
+ private long initialDelayInSeconds(long maxInitialDelayInSeconds) {
+ return random.nextLong() % maxInitialDelayInSeconds;
+ }
+
+ private void updateSynchronously() {
+ lastHealthInfo = healthClient.getHealthInfo();
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java
index 5a4b7251ae2..473ef5e3a94 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java
@@ -2,8 +2,8 @@
package com.yahoo.vespa.service.monitor.internal.health;
import com.google.inject.Inject;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.api.ApplicationInfo;
-import com.yahoo.config.model.api.SuperModel;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.vespa.applicationmodel.ClusterId;
import com.yahoo.vespa.applicationmodel.ConfigId;
@@ -12,19 +12,38 @@ import com.yahoo.vespa.applicationmodel.ServiceType;
import com.yahoo.vespa.service.monitor.application.ZoneApplication;
import com.yahoo.vespa.service.monitor.internal.MonitorManager;
+import java.util.HashMap;
+import java.util.Map;
+
/**
* @author hakon
*/
public class HealthMonitorManager implements MonitorManager {
+ private final Map<ApplicationId, ApplicationHealthMonitor> healthMonitors = new HashMap<>();
+ private final ConfigserverConfig configserverConfig;
+
@Inject
- public HealthMonitorManager() {}
+ public HealthMonitorManager(ConfigserverConfig configserverConfig) {
+ this.configserverConfig = configserverConfig;
+ }
@Override
- public void applicationActivated(SuperModel superModel, ApplicationInfo application) {
+ public void applicationActivated(ApplicationInfo application) {
+ if (applicationMonitored(application.getApplicationId())) {
+ ApplicationHealthMonitor monitor =
+ ApplicationHealthMonitor.startMonitoring(application);
+ healthMonitors.put(application.getApplicationId(), monitor);
+ }
}
@Override
- public void applicationRemoved(SuperModel superModel, ApplicationId id) {
+ public void applicationRemoved(ApplicationId id) {
+ if (applicationMonitored(id)) {
+ ApplicationHealthMonitor monitor = healthMonitors.remove(id);
+ if (monitor != null) {
+ monitor.close();
+ }
+ }
}
@Override
@@ -32,13 +51,18 @@ public class HealthMonitorManager implements MonitorManager {
ClusterId clusterId,
ServiceType serviceType,
ConfigId configId) {
- // TODO: Do proper health check
- if (ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType)) {
+ if (!configserverConfig.nodeAdminInContainer() &&
+ ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType)) {
+ // If node admin doesn't run in a JDisc container, it must be monitored with health.
+ // TODO: Do proper health check
return ServiceStatus.UP;
}
- throw new IllegalArgumentException("Health monitoring not implemented for application " +
- applicationId.toShortString() + ", cluster " + clusterId.s() + ", serviceType " +
- serviceType);
+ return ServiceStatus.NOT_CHECKED;
+ }
+
+ private boolean applicationMonitored(ApplicationId id) {
+ // todo: health-check config server
+ return false;
}
}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java
new file mode 100644
index 00000000000..574523ad564
--- /dev/null
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java
@@ -0,0 +1,35 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.yahoo.text.JSON;
+
+/**
+ * Response entity from /state/v1/health
+ *
+ * @author hakon
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class HealthResponse {
+ @JsonProperty("status")
+ public Status status = new Status();
+
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public static class Status {
+ public static final String DEFAULT_STATUS = "down";
+
+ @JsonProperty("code")
+ public String code = DEFAULT_STATUS;
+
+ @Override
+ public String toString() {
+ return "{ \"code\": \"" + JSON.escape(code) + "\" }";
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "{ \"status\": " + status.toString() + " }";
+ }
+}
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java
index aaaab22e742..68958c94dfd 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java
@@ -3,8 +3,6 @@ package com.yahoo.vespa.service.monitor.internal.slobrok;
import com.google.inject.Inject;
import com.yahoo.config.model.api.ApplicationInfo;
-import com.yahoo.config.model.api.SuperModel;
-import com.yahoo.config.model.api.SuperModelListener;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.jrt.slobrok.api.Mirror;
import com.yahoo.log.LogLevel;
@@ -13,6 +11,7 @@ import com.yahoo.vespa.applicationmodel.ConfigId;
import com.yahoo.vespa.applicationmodel.ServiceStatus;
import com.yahoo.vespa.applicationmodel.ServiceType;
import com.yahoo.vespa.service.monitor.SlobrokApi;
+import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
import com.yahoo.vespa.service.monitor.internal.MonitorManager;
import java.util.HashMap;
@@ -21,7 +20,7 @@ import java.util.Optional;
import java.util.function.Supplier;
import java.util.logging.Logger;
-public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi, MonitorManager {
+public class SlobrokMonitorManagerImpl implements SlobrokApi, MonitorManager {
private static final Logger logger =
Logger.getLogger(SlobrokMonitorManagerImpl.class.getName());
@@ -40,7 +39,11 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi
}
@Override
- public void applicationActivated(SuperModel superModel, ApplicationInfo application) {
+ public void applicationActivated(ApplicationInfo application) {
+ if (!applicationMonitoredWithSlobrok(application.getApplicationId())) {
+ return;
+ }
+
synchronized (monitor) {
SlobrokMonitor slobrokMonitor = slobrokMonitors.computeIfAbsent(
application.getApplicationId(),
@@ -50,7 +53,11 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi
}
@Override
- public void applicationRemoved(SuperModel superModel, ApplicationId id) {
+ public void applicationRemoved(ApplicationId id) {
+ if (!applicationMonitoredWithSlobrok(id)) {
+ return;
+ }
+
synchronized (monitor) {
SlobrokMonitor slobrokMonitor = slobrokMonitors.remove(id);
if (slobrokMonitor == null) {
@@ -79,6 +86,10 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi
ClusterId clusterId,
ServiceType serviceType,
ConfigId configId) {
+ if (!applicationMonitoredWithSlobrok(applicationId)) {
+ return ServiceStatus.NOT_CHECKED;
+ }
+
Optional<String> slobrokServiceName = findSlobrokServiceName(serviceType, configId);
if (slobrokServiceName.isPresent()) {
synchronized (monitor) {
@@ -95,6 +106,14 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi
}
}
+ private boolean applicationMonitoredWithSlobrok(ApplicationId applicationId) {
+ if (applicationId.equals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId())) {
+ return false;
+ }
+
+ return true;
+ }
+
/**
* Get the Slobrok service name of the service, or empty if the service
* is not registered with Slobrok.
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGeneratorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGeneratorTest.java
index 58f99786017..899cc59bb34 100644
--- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGeneratorTest.java
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGeneratorTest.java
@@ -1,22 +1,27 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.application;
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.applicationmodel.ApplicationInstance;
import com.yahoo.vespa.applicationmodel.ServiceStatus;
import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
+import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil;
import org.junit.Test;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
-public class ConfigServerAppGeneratorTest {
+public class ApplicationInstanceGeneratorTest {
private static final String configServer1 = "cfg1.yahoo.com";
private static final String configServer2 = "cfg2.yahoo.com";
private static final String configServer3 = "cfg3.yahoo.com";
@@ -28,9 +33,17 @@ public class ConfigServerAppGeneratorTest {
private final ServiceStatusProvider statusProvider = mock(ServiceStatusProvider.class);
@Test
- public void toApplicationInstance() throws Exception {
+ public void toApplicationInstance() {
when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED);
- ApplicationInstance applicationInstance = new ConfigServerAppGenerator(configServerList)
+ ConfigserverConfig config = ConfigserverUtil.create(
+ true,
+ true,
+ configServer1,
+ configServer2,
+ configServer3);
+ Zone zone = mock(Zone.class);
+ ApplicationInfo configServer = CONFIG_SERVER_APPLICATION.makeApplicationInfo(config);
+ ApplicationInstance applicationInstance = new ApplicationInstanceGenerator(configServer, zone)
.makeApplicationInstance(statusProvider);
assertEquals(
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java
new file mode 100644
index 00000000000..85df02949a6
--- /dev/null
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java
@@ -0,0 +1,52 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
+
+/**
+ * @author hakon
+ */
+public class ConfigserverUtil {
+ /** Create a ConfigserverConfig with the given settings. */
+ public static ConfigserverConfig create(
+ boolean nodeAdminInContainer,
+ boolean multitenant,
+ String configServerHostname1,
+ String configServerHostname2,
+ String configServerHostname3) {
+ return new ConfigserverConfig(
+ new ConfigserverConfig.Builder()
+ .nodeAdminInContainer(nodeAdminInContainer)
+ .multitenant(multitenant)
+ .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname1).port(1))
+ .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname2).port(2))
+ .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname3).port(3)));
+ }
+
+ public static ConfigserverConfig createExampleConfigserverConfig() {
+ return create(true, true, "cfg1", "cfg2", "cfg3");
+ }
+
+ public static ConfigserverConfig createExampleConfigserverConfig(boolean nodeAdminInContainer,
+ boolean multitenant) {
+ return create(nodeAdminInContainer, multitenant, "cfg1", "cfg2", "cfg3");
+ }
+
+ public static ApplicationInfo makeConfigServerApplicationInfo(
+ String configServerHostname1,
+ String configServerHostname2,
+ String configServerHostname3) {
+ return ConfigServerApplication.CONFIG_SERVER_APPLICATION.makeApplicationInfo(create(
+ true,
+ true,
+ configServerHostname1,
+ configServerHostname2,
+ configServerHostname3));
+ }
+
+ public static ApplicationInfo makeExampleConfigServer() {
+ return makeConfigServerApplicationInfo("cfg1", "cfg2", "cfg3");
+ }
+}
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java
new file mode 100644
index 00000000000..c9d19d0ccd9
--- /dev/null
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java
@@ -0,0 +1,53 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.config.model.api.SuperModel;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.vespa.service.monitor.ServiceStatusProvider;
+import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author hakon
+ */
+public class DuperModelTest {
+ private final ServiceStatusProvider statusProvider = mock(ServiceStatusProvider.class);
+
+ @Test
+ public void toApplicationInstance() {
+ when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED);
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig();
+ DuperModel duperModel = new DuperModel(config);
+ SuperModel superModel = mock(SuperModel.class);
+ ApplicationInfo superModelApplicationInfo = mock(ApplicationInfo.class);
+ when(superModel.getAllApplicationInfos()).thenReturn(Collections.singletonList(superModelApplicationInfo));
+ List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel);
+ assertEquals(2, applicationInfos.size());
+ assertEquals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(), applicationInfos.get(0).getApplicationId());
+ assertSame(superModelApplicationInfo, applicationInfos.get(1));
+ }
+
+ @Test
+ public void toApplicationInstanceInSingleTenantMode() {
+ when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED);
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(true, false);
+ DuperModel duperModel = new DuperModel(config);
+ SuperModel superModel = mock(SuperModel.class);
+ ApplicationInfo superModelApplicationInfo = mock(ApplicationInfo.class);
+ when(superModel.getAllApplicationInfos()).thenReturn(Collections.singletonList(superModelApplicationInfo));
+ List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel);
+ assertEquals(1, applicationInfos.size());
+ assertSame(superModelApplicationInfo, applicationInfos.get(0));
+ }
+}
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java
index a21691ee4d0..5a57451a298 100644
--- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.internal;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.api.SuperModel;
import com.yahoo.config.provision.Environment;
import com.yahoo.config.provision.RegionName;
@@ -15,13 +16,9 @@ import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl;
import org.junit.Test;
-import java.util.Collections;
import java.util.Iterator;
-import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
@@ -35,13 +32,12 @@ public class ModelGeneratorTest {
private final int PORT = 2;
@Test
- public void toApplicationModelWithConfigServerApplication() throws Exception {
- SuperModel superModel =
- ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT);
+ public void toApplicationModel() throws Exception {
+ SuperModel superModel = ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT);
- List<String> configServerHosts = Stream.of("cfg1", "cfg2", "cfg3")
- .collect(Collectors.toList());
- ModelGenerator modelGenerator = new ModelGenerator(configServerHosts);
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig();
+ DuperModel duperModel = new DuperModel(config);
+ ModelGenerator modelGenerator = new ModelGenerator();
Zone zone = new Zone(Environment.from(ENVIRONMENT), RegionName.from(REGION));
@@ -51,7 +47,7 @@ public class ModelGeneratorTest {
ServiceModel serviceModel =
modelGenerator.toServiceModel(
- superModel,
+ duperModel.getApplicationInfos(superModel),
zone,
slobrokMonitorManager);
@@ -78,32 +74,6 @@ public class ModelGeneratorTest {
}
}
- @Test
- public void toApplicationModel() throws Exception {
- SuperModel superModel =
- ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT);
- ModelGenerator modelGenerator = new ModelGenerator(Collections.emptyList());
-
- Zone zone = new Zone(Environment.from(ENVIRONMENT), RegionName.from(REGION));
-
- SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class);
- when(slobrokMonitorManager.getStatus(any(), any(), any(), any()))
- .thenReturn(ServiceStatus.UP);
-
- ServiceModel serviceModel =
- modelGenerator.toServiceModel(
- superModel,
- zone,
- slobrokMonitorManager);
-
- Map<ApplicationInstanceReference,
- ApplicationInstance> applicationInstances =
- serviceModel.getAllApplicationInstances();
-
- assertEquals(1, applicationInstances.size());
- verifyOtherApplication(applicationInstances.values().iterator().next());
- }
-
private void verifyOtherApplication(ApplicationInstance applicationInstance) {
assertEquals(String.format("%s:%s:%s:%s:%s",
ExampleModel.TENANT,
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java
index 83bad0ddb2a..eb6d6d583f7 100644
--- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java
@@ -14,6 +14,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -22,11 +23,13 @@ public class SuperModelListenerImplTest {
public void sanityCheck() {
SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class);
ServiceMonitorMetrics metrics = mock(ServiceMonitorMetrics.class);
+ DuperModel duperModel = mock(DuperModel.class);
ModelGenerator modelGenerator = mock(ModelGenerator.class);
Zone zone = mock(Zone.class);
SuperModelListenerImpl listener = new SuperModelListenerImpl(
slobrokMonitorManager,
metrics,
+ duperModel,
modelGenerator,
zone);
@@ -38,13 +41,15 @@ public class SuperModelListenerImplTest {
ApplicationInfo application2 = mock(ApplicationInfo.class);
List<ApplicationInfo> applications = Stream.of(application1, application2)
.collect(Collectors.toList());
- when(superModel.getAllApplicationInfos()).thenReturn(applications);
+ when(duperModel.getApplicationInfos(superModel)).thenReturn(applications);
listener.start(superModelProvider);
- verify(slobrokMonitorManager).applicationActivated(superModel, application1);
- verify(slobrokMonitorManager).applicationActivated(superModel, application2);
+ verify(duperModel, times(1)).getApplicationInfos(superModel);
+ verify(slobrokMonitorManager).applicationActivated(application1);
+ verify(slobrokMonitorManager).applicationActivated(application2);
ServiceModel serviceModel = listener.get();
- verify(modelGenerator).toServiceModel(superModel, zone, slobrokMonitorManager);
+ verify(duperModel, times(2)).getApplicationInfos(superModel);
+ verify(modelGenerator).toServiceModel(applications, zone, slobrokMonitorManager);
}
} \ No newline at end of file
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java
index b7c3ed8e1e1..79916e43712 100644
--- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java
@@ -1,95 +1,44 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.service.monitor.internal;
-import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.config.provision.ApplicationId;
-import com.yahoo.vespa.applicationmodel.ClusterId;
import com.yahoo.vespa.applicationmodel.ConfigId;
-import com.yahoo.vespa.applicationmodel.ServiceType;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager;
import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl;
import org.junit.Test;
import static com.yahoo.vespa.applicationmodel.ClusterId.NODE_ADMIN;
+import static com.yahoo.vespa.applicationmodel.ServiceStatus.*;
+import static com.yahoo.vespa.applicationmodel.ServiceStatus.NOT_CHECKED;
+import static com.yahoo.vespa.applicationmodel.ServiceStatus.UP;
import static com.yahoo.vespa.applicationmodel.ServiceType.CONTAINER;
import static com.yahoo.vespa.service.monitor.application.ZoneApplication.ZONE_APPLICATION_ID;
+import static org.junit.Assert.assertSame;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
public class UnionMonitorManagerTest {
- @Test
- public void nodeAdminInContainer() {
- testWith(
- true,
- ZONE_APPLICATION_ID,
- NODE_ADMIN,
- CONTAINER,
- 1,
- 0);
- }
-
- @Test
- public void nodeAdminOutsideContainer() {
- boolean inContainer = false;
-
- // When nodeAdminInContainer is set, then only the node admin cluster should use health
- testWith(
- inContainer,
- ZONE_APPLICATION_ID,
- NODE_ADMIN,
- CONTAINER,
- 0,
- 1);
-
- testWith(
- inContainer,
- ApplicationId.fromSerializedForm("a:b:default"),
- NODE_ADMIN,
- CONTAINER,
- 1,
- 0);
+ private final SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class);
+ private final HealthMonitorManager healthMonitorManager = mock(HealthMonitorManager.class);
- testWith(
- inContainer,
- ZONE_APPLICATION_ID,
- new ClusterId("foo"),
- CONTAINER,
- 1,
- 0);
+ private final UnionMonitorManager manager = new UnionMonitorManager(
+ slobrokMonitorManager,
+ healthMonitorManager);
- testWith(
- inContainer,
- ZONE_APPLICATION_ID,
- NODE_ADMIN,
- new ServiceType("foo"),
- 1,
- 0);
+ @Test
+ public void verifyHealthTakesPriority() {
+ testWith(UP, DOWN, UP);
+ testWith(NOT_CHECKED, DOWN, DOWN);
+ testWith(NOT_CHECKED, NOT_CHECKED, NOT_CHECKED);
}
- private void testWith(boolean nodeAdminInContainer,
- ApplicationId applicationId,
- ClusterId clusterId,
- ServiceType serviceType,
- int expectedSlobrokCalls,
- int expectedHealthCalls) {
- SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class);
- HealthMonitorManager healthMonitorManager = mock(HealthMonitorManager.class);
-
- ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder();
- builder.nodeAdminInContainer(nodeAdminInContainer);
- ConfigserverConfig config = new ConfigserverConfig(builder);
-
-
- UnionMonitorManager manager = new UnionMonitorManager(
- slobrokMonitorManager,
- healthMonitorManager,
- config);
-
- manager.getStatus(applicationId, clusterId, serviceType, new ConfigId("config-id"));
-
- verify(slobrokMonitorManager, times(expectedSlobrokCalls)).getStatus(any(), any(), any(), any());
- verify(healthMonitorManager, times(expectedHealthCalls)).getStatus(any(), any(), any(), any());
+ private void testWith(ServiceStatus healthStatus,
+ ServiceStatus slobrokStatus,
+ ServiceStatus expectedStatus) {
+ when(healthMonitorManager.getStatus(any(), any(), any(), any())).thenReturn(healthStatus);
+ when(slobrokMonitorManager.getStatus(any(), any(), any(), any())).thenReturn(slobrokStatus);
+ ServiceStatus status = manager.getStatus(ZONE_APPLICATION_ID, NODE_ADMIN, CONTAINER, new ConfigId("config-id"));
+ assertSame(expectedStatus, status);
}
} \ No newline at end of file
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java
new file mode 100644
index 00000000000..51b0503565f
--- /dev/null
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java
@@ -0,0 +1,24 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.vespa.service.monitor.application.ConfigServerApplication;
+import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil;
+import org.junit.Test;
+
+import static com.yahoo.vespa.applicationmodel.ServiceStatus.NOT_CHECKED;
+import static org.junit.Assert.assertEquals;
+
+public class ApplicationHealthMonitorTest {
+ @Test
+ public void sanityCheck() {
+ ApplicationHealthMonitor monitor = ApplicationHealthMonitor.startMonitoring(
+ ConfigserverUtil.makeExampleConfigServer());
+ ServiceStatus status = monitor.getStatus(
+ ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(),
+ ConfigServerApplication.CLUSTER_ID,
+ ConfigServerApplication.SERVICE_TYPE,
+ ConfigServerApplication.configIdFrom(0));
+ assertEquals(NOT_CHECKED, status);
+ }
+} \ No newline at end of file
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java
new file mode 100644
index 00000000000..b9d25406f9b
--- /dev/null
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java
@@ -0,0 +1,49 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.config.model.api.ApplicationInfo;
+import com.yahoo.vespa.applicationmodel.ClusterId;
+import com.yahoo.vespa.applicationmodel.ConfigId;
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import com.yahoo.vespa.applicationmodel.ServiceType;
+import com.yahoo.vespa.service.monitor.application.ZoneApplication;
+import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class HealthMonitorManagerTest {
+ @Test
+ public void addRemove() {
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig();
+ HealthMonitorManager manager = new HealthMonitorManager(config);
+ ApplicationInfo applicationInfo = ConfigserverUtil.makeExampleConfigServer();
+ manager.applicationActivated(applicationInfo);
+ manager.applicationRemoved(applicationInfo.getApplicationId());
+ }
+
+ @Test
+ public void withNodeAdmin() {
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig();
+ HealthMonitorManager manager = new HealthMonitorManager(config);
+ ServiceStatus status = manager.getStatus(
+ ZoneApplication.ZONE_APPLICATION_ID,
+ ClusterId.NODE_ADMIN,
+ ServiceType.CONTAINER,
+ new ConfigId("config-id-1"));
+ assertEquals(ServiceStatus.NOT_CHECKED, status);
+ }
+
+ @Test
+ public void withHostAdmin() {
+ ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(false, true);
+ HealthMonitorManager manager = new HealthMonitorManager(config);
+ ServiceStatus status = manager.getStatus(
+ ZoneApplication.ZONE_APPLICATION_ID,
+ ClusterId.NODE_ADMIN,
+ ServiceType.CONTAINER,
+ new ConfigId("config-id-1"));
+ assertEquals(ServiceStatus.UP, status);
+ }
+} \ No newline at end of file
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java
new file mode 100644
index 00000000000..cca1530ad97
--- /dev/null
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java
@@ -0,0 +1,21 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.service.monitor.internal.health;
+
+import com.yahoo.vespa.applicationmodel.ServiceStatus;
+import org.junit.Test;
+
+import java.net.MalformedURLException;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+public class HealthMonitorTest {
+ @Test
+ public void basicTests() throws MalformedURLException {
+ HealthClient healthClient = mock(HealthClient.class);
+ try (HealthMonitor monitor = new HealthMonitor(healthClient)) {
+ monitor.startMonitoring();
+ assertEquals(ServiceStatus.NOT_CHECKED, monitor.getStatus());
+ }
+ }
+} \ No newline at end of file
diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java
index 8e4443df83b..a567559980b 100644
--- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java
+++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java
@@ -2,7 +2,7 @@
package com.yahoo.vespa.service.monitor.internal.slobrok;
import com.yahoo.config.model.api.ApplicationInfo;
-import com.yahoo.config.model.api.SuperModel;
+import com.yahoo.config.provision.ApplicationId;
import com.yahoo.vespa.applicationmodel.ClusterId;
import com.yahoo.vespa.applicationmodel.ConfigId;
import com.yahoo.vespa.applicationmodel.ServiceStatus;
@@ -28,18 +28,19 @@ public class SlobrokMonitorManagerImplTest {
private final SlobrokMonitorManagerImpl slobrokMonitorManager =
new SlobrokMonitorManagerImpl(slobrokMonitorFactory);
private final SlobrokMonitor slobrokMonitor = mock(SlobrokMonitor.class);
- private final SuperModel superModel = mock(SuperModel.class);
+ private final ApplicationId applicationId = ApplicationId.from("tenant", "app", "instance");
private final ApplicationInfo application = mock(ApplicationInfo.class);
private final ClusterId clusterId = new ClusterId("cluster-id");
@Before
public void setup() {
when(slobrokMonitorFactory.get()).thenReturn(slobrokMonitor);
+ when(application.getApplicationId()).thenReturn(applicationId);
}
@Test
public void testActivationOfApplication() {
- slobrokMonitorManager.applicationActivated(superModel, application);
+ slobrokMonitorManager.applicationActivated(application);
verify(slobrokMonitorFactory, times(1)).get();
}
@@ -51,14 +52,14 @@ public class SlobrokMonitorManagerImplTest {
@Test
public void testGetStatus_ApplicationInSlobrok() {
- slobrokMonitorManager.applicationActivated(superModel, application);
+ slobrokMonitorManager.applicationActivated(application);
when(slobrokMonitor.registeredInSlobrok("config.id")).thenReturn(true);
assertEquals(ServiceStatus.UP, getStatus("topleveldispatch"));
}
@Test
public void testGetStatus_ServiceNotInSlobrok() {
- slobrokMonitorManager.applicationActivated(superModel, application);
+ slobrokMonitorManager.applicationActivated(application);
when(slobrokMonitor.registeredInSlobrok("config.id")).thenReturn(false);
assertEquals(ServiceStatus.DOWN, getStatus("topleveldispatch"));
}
diff --git a/valgrind-suppressions.txt b/valgrind-suppressions.txt
index 2df6c9c5691..2587552ceff 100644
--- a/valgrind-suppressions.txt
+++ b/valgrind-suppressions.txt
@@ -339,3 +339,20 @@
fun:__static_initialization_and_destruction_0
...
}
+{
+ Apparent memory leak on Fedora 28.
+ Memcheck:Leak
+ match-leak-kinds: possible
+ fun:malloc
+ fun:tsearch
+ fun:__add_to_environ
+ fun:setenv
+}
+{
+ Apparent memory leak on Fedora 28.
+ Memcheck:Leak
+ match-leak-kinds: possible
+ fun:malloc
+ fun:__add_to_environ
+ fun:setenv
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
index 1504119d9cc..ab127b19bf1 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
@@ -10,8 +10,13 @@ import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocume
import com.yahoo.vespa.athenz.identityprovider.api.bindings.VespaUniqueInstanceIdEntity;
import com.yahoo.vespa.athenz.utils.AthenzIdentities;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Path;
import java.util.Base64;
+import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId.fromDottedString;
+
/**
* Utility class for mapping objects model types and their Jackson binding versions.
*
@@ -33,7 +38,7 @@ public class EntityBindingsMapper {
public static VespaUniqueInstanceId toVespaUniqueInstanceId(VespaUniqueInstanceIdEntity entity) {
return new VespaUniqueInstanceId(
- entity.clusterIndex, entity.clusterId, entity.instance, entity.application, entity.tenant, entity.region, entity.environment);
+ entity.clusterIndex, entity.clusterId, entity.instance, entity.application, entity.tenant, entity.region, entity.environment, entity.type != null ? IdentityType.fromId(entity.type) : null); // TODO Remove support for legacy representation without type
}
public static IdentityDocument toIdentityDocument(IdentityDocumentEntity entity) {
@@ -50,17 +55,22 @@ public class EntityBindingsMapper {
toIdentityDocument(entity.identityDocument),
entity.signature,
entity.signingKeyVersion,
- VespaUniqueInstanceId.fromDottedString(entity.providerUniqueId),
+ fromDottedString(entity.providerUniqueId),
entity.dnsSuffix,
(AthenzService) AthenzIdentities.from(entity.providerService),
entity.ztsEndpoint,
- entity.documentVersion);
+ entity.documentVersion,
+ entity.configServerHostname,
+ entity.instanceHostname,
+ entity.createdAt,
+ entity.ipAddresses,
+ entity.identityType != null ? IdentityType.fromId(entity.identityType) : null); // TODO Remove support for legacy representation without type
}
public static VespaUniqueInstanceIdEntity toVespaUniqueInstanceIdEntity(VespaUniqueInstanceId model) {
return new VespaUniqueInstanceIdEntity(
model.tenant(), model.application(), model.environment(), model.region(),
- model.instance(), model.clusterId(), model.clusterIndex());
+ model.instance(), model.clusterId(), model.clusterIndex(), model.type() != null ? model.type().id() : null); // TODO Remove support for legacy representation without type
}
public static IdentityDocumentEntity toIdentityDocumentEntity(IdentityDocument model) {
@@ -84,10 +94,33 @@ public class EntityBindingsMapper {
model.dnsSuffix(),
model.providerService().getFullName(),
model.ztsEndpoint(),
- model.documentVersion());
+ model.documentVersion(),
+ model.configServerHostname(),
+ model.instanceHostname(),
+ model.createdAt(),
+ model.ipAddresses(),
+ model.identityType() != null ? model.identityType().id() : null); // TODO Remove support for legacy representation without type
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
+ public static SignedIdentityDocument readSignedIdentityDocumentFromFile(Path file) {
+ try {
+ SignedIdentityDocumentEntity entity = mapper.readValue(file.toFile(), SignedIdentityDocumentEntity.class);
+ return EntityBindingsMapper.toSignedIdentityDocument(entity);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ public static void writeSignedIdentityDocumentToFile(Path file, SignedIdentityDocument document) {
+ try {
+ SignedIdentityDocumentEntity entity = EntityBindingsMapper.toSignedIdentityDocumentEntity(document);
+ mapper.writeValue(file.toFile(), entity);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java
index 8da2bd0a343..82d0a3d622c 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java
@@ -8,7 +8,9 @@ import java.util.Set;
* The identity document that contains the instance specific information
*
* @author bjorncs
+ * @deprecated Will soon be inlined into {@link SignedIdentityDocument}
*/
+@Deprecated
public class IdentityDocument {
private final VespaUniqueInstanceId providerUniqueId;
private final String configServerHostname;
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java
new file mode 100644
index 00000000000..4ca2e34a618
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java
@@ -0,0 +1,25 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api;
+
+import java.util.Arrays;
+
+/**
+ * Represents the types of identities that the configserver can provide.
+ *
+ * @author bjorncs
+ */
+public enum IdentityType {TENANT("tenant"), NODE("node");
+ private final String id;
+
+ IdentityType(String id) { this.id = id; }
+
+ public String id() { return id; }
+
+ public static IdentityType fromId(String id) {
+ return Arrays.stream(values())
+ .filter(v -> v.id.equals(id))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Invalid id: " + id));
+ }
+}
+
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
index d184efc0221..60be42544c7 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
@@ -4,6 +4,8 @@ package com.yahoo.vespa.athenz.identityprovider.api;
import com.yahoo.vespa.athenz.api.AthenzService;
import java.net.URI;
+import java.time.Instant;
+import java.util.Set;
/**
* A signed identity document which contains a {@link IdentityDocument}
@@ -22,6 +24,11 @@ public class SignedIdentityDocument {
private final AthenzService providerService;
private final URI ztsEndpoint;
private final int documentVersion;
+ private final String configServerHostname;
+ private final String instanceHostname;
+ private final Instant createdAt;
+ private final Set<String> ipAddresses;
+ private final IdentityType identityType;
public SignedIdentityDocument(IdentityDocument identityDocument,
String signature,
@@ -30,7 +37,12 @@ public class SignedIdentityDocument {
String dnsSuffix,
AthenzService providerService,
URI ztsEndpoint,
- int documentVersion) {
+ int documentVersion,
+ String configServerHostname,
+ String instanceHostname,
+ Instant createdAt,
+ Set<String> ipAddresses,
+ IdentityType identityType) {
this.identityDocument = identityDocument;
this.signature = signature;
this.signingKeyVersion = signingKeyVersion;
@@ -39,6 +51,11 @@ public class SignedIdentityDocument {
this.providerService = providerService;
this.ztsEndpoint = ztsEndpoint;
this.documentVersion = documentVersion;
+ this.configServerHostname = configServerHostname;
+ this.instanceHostname = instanceHostname;
+ this.createdAt = createdAt;
+ this.ipAddresses = ipAddresses;
+ this.identityType = identityType;
}
public IdentityDocument identityDocument() {
@@ -72,4 +89,24 @@ public class SignedIdentityDocument {
public int documentVersion() {
return documentVersion;
}
+
+ public String configServerHostname() {
+ return configServerHostname;
+ }
+
+ public String instanceHostname() {
+ return instanceHostname;
+ }
+
+ public Instant createdAt() {
+ return createdAt;
+ }
+
+ public Set<String> ipAddresses() {
+ return ipAddresses;
+ }
+
+ public IdentityType identityType() {
+ return identityType;
+ }
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java
index 5539ba53882..be94cc59691 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java
@@ -4,6 +4,8 @@ package com.yahoo.vespa.athenz.identityprovider.api;
import java.util.Objects;
/**
+ * Represents the unique instance id as used in Vespa's integration with Athenz Copper Argos
+ *
* @author bjorncs
*/
public class VespaUniqueInstanceId {
@@ -15,6 +17,7 @@ public class VespaUniqueInstanceId {
private final String tenant;
private final String region;
private final String environment;
+ private final IdentityType type;
public VespaUniqueInstanceId(int clusterIndex,
String clusterId,
@@ -22,7 +25,8 @@ public class VespaUniqueInstanceId {
String application,
String tenant,
String region,
- String environment) {
+ String environment,
+ IdentityType type) {
this.clusterIndex = clusterIndex;
this.clusterId = clusterId;
this.instance = instance;
@@ -30,21 +34,43 @@ public class VespaUniqueInstanceId {
this.tenant = tenant;
this.region = region;
this.environment = environment;
+ this.type = type;
}
+ // TODO Remove support for legacy representation without type
+ @Deprecated
+ public VespaUniqueInstanceId(int clusterIndex,
+ String clusterId,
+ String instance,
+ String application,
+ String tenant,
+ String region,
+ String environment) {
+ this(clusterIndex, clusterId, instance, application, tenant, region, environment, null);
+ }
+
+
+ // TODO Remove support for legacy representation without type
public static VespaUniqueInstanceId fromDottedString(String instanceId) {
String[] tokens = instanceId.split("\\.");
- if (tokens.length != 7) {
+ if (tokens.length != 7 && tokens.length != 8) {
throw new IllegalArgumentException("Invalid instance id: " + instanceId);
}
return new VespaUniqueInstanceId(
- Integer.parseInt(tokens[0]), tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]);
+ Integer.parseInt(tokens[0]), tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6], tokens.length == 8 ? IdentityType.fromId(tokens[7]) : null);
}
+ // TODO Remove support for legacy representation without type
public String asDottedString() {
- return String.format(
- "%d.%s.%s.%s.%s.%s.%s",
- clusterIndex, clusterId, instance, application, tenant, region, environment);
+ if (type != null) {
+ return String.format(
+ "%d.%s.%s.%s.%s.%s.%s.%s",
+ clusterIndex, clusterId, instance, application, tenant, region, environment, type.id());
+ } else {
+ return String.format(
+ "%d.%s.%s.%s.%s.%s.%s",
+ clusterIndex, clusterId, instance, application, tenant, region, environment);
+ }
}
public int clusterIndex() {
@@ -75,6 +101,8 @@ public class VespaUniqueInstanceId {
return environment;
}
+ public IdentityType type() { return type; }
+
@Override
public String toString() {
return "VespaUniqueInstanceId{" +
@@ -85,6 +113,7 @@ public class VespaUniqueInstanceId {
", tenant='" + tenant + '\'' +
", region='" + region + '\'' +
", environment='" + environment + '\'' +
+ ", type=" + type +
'}';
}
@@ -99,11 +128,12 @@ public class VespaUniqueInstanceId {
Objects.equals(application, that.application) &&
Objects.equals(tenant, that.tenant) &&
Objects.equals(region, that.region) &&
- Objects.equals(environment, that.environment);
+ Objects.equals(environment, that.environment) &&
+ type == that.type;
}
@Override
public int hashCode() {
- return Objects.hash(clusterIndex, clusterId, instance, application, tenant, region, environment);
+ return Objects.hash(clusterIndex, clusterId, instance, application, tenant, region, environment, type);
}
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java
index 775a49349a3..fc5392411c1 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java
@@ -5,7 +5,6 @@ import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
-import javax.ws.rs.QueryParam;
import javax.ws.rs.core.MediaType;
/**
@@ -16,11 +15,6 @@ public interface IdentityDocumentApi {
@GET
@Produces(MediaType.APPLICATION_JSON)
- @Deprecated
- SignedIdentityDocumentEntity getIdentityDocument(@QueryParam("hostname") String hostname);
-
- @GET
- @Produces(MediaType.APPLICATION_JSON)
@Path("/node/{host}")
SignedIdentityDocumentEntity getNodeIdentityDocument(@PathParam("host") String host);
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java
index 58a4f1e24bf..b4b2e82ab0e 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java
@@ -10,8 +10,10 @@ import java.util.Set;
/**
* @author bjorncs
+ * @deprecated Will soon be inlined into {@link SignedIdentityDocumentEntity}
*/
@JsonIgnoreProperties(ignoreUnknown = true)
+@Deprecated
public class IdentityDocumentEntity {
@JsonProperty("provider-unique-id")
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
index e397b81ef9e..aa514b3caf3 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
@@ -11,8 +11,10 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
+import java.time.Instant;
import java.util.Base64;
import java.util.Objects;
+import java.util.Set;
/**
* @author bjorncs
@@ -31,6 +33,11 @@ public class SignedIdentityDocumentEntity {
@JsonProperty("provider-service") public final String providerService;
@JsonProperty("zts-endpoint") public final URI ztsEndpoint;
@JsonProperty("document-version") public final int documentVersion;
+ @JsonProperty("configserver-hostname") public final String configServerHostname;
+ @JsonProperty("instance-hostname") public final String instanceHostname;
+ @JsonProperty("created-at") public final Instant createdAt;
+ @JsonProperty("ip-addresses") public final Set<String> ipAddresses;
+ @JsonProperty("identity-type") public final String identityType;
@JsonCreator
public SignedIdentityDocumentEntity(@JsonProperty("identity-document") String rawIdentityDocument,
@@ -40,7 +47,12 @@ public class SignedIdentityDocumentEntity {
@JsonProperty("dns-suffix") String dnsSuffix,
@JsonProperty("provider-service") String providerService,
@JsonProperty("zts-endpoint") URI ztsEndpoint,
- @JsonProperty("document-version") int documentVersion) {
+ @JsonProperty("document-version") int documentVersion,
+ @JsonProperty("configserver-hostname") String configServerHostname,
+ @JsonProperty("instance-hostname") String instanceHostname,
+ @JsonProperty("created-at") Instant createdAt,
+ @JsonProperty("ip-addresses") Set<String> ipAddresses,
+ @JsonProperty("identity-type") String identityType) {
this.rawIdentityDocument = rawIdentityDocument;
this.identityDocument = parseIdentityDocument(rawIdentityDocument);
this.signature = signature;
@@ -50,6 +62,11 @@ public class SignedIdentityDocumentEntity {
this.providerService = providerService;
this.ztsEndpoint = ztsEndpoint;
this.documentVersion = documentVersion;
+ this.configServerHostname = configServerHostname;
+ this.instanceHostname = instanceHostname;
+ this.createdAt = createdAt;
+ this.ipAddresses = ipAddresses;
+ this.identityType = identityType;
}
private static IdentityDocumentEntity parseIdentityDocument(String rawIdentityDocument) {
@@ -73,7 +90,16 @@ public class SignedIdentityDocumentEntity {
", identityDocument=" + identityDocument +
", signature='" + signature + '\'' +
", signingKeyVersion=" + signingKeyVersion +
+ ", providerUniqueId='" + providerUniqueId + '\'' +
+ ", dnsSuffix='" + dnsSuffix + '\'' +
+ ", providerService='" + providerService + '\'' +
+ ", ztsEndpoint=" + ztsEndpoint +
", documentVersion=" + documentVersion +
+ ", configServerHostname='" + configServerHostname + '\'' +
+ ", instanceHostname='" + instanceHostname + '\'' +
+ ", createdAt=" + createdAt +
+ ", ipAddresses=" + ipAddresses +
+ ", identityType=" + identityType +
'}';
}
@@ -86,11 +112,20 @@ public class SignedIdentityDocumentEntity {
documentVersion == that.documentVersion &&
Objects.equals(rawIdentityDocument, that.rawIdentityDocument) &&
Objects.equals(identityDocument, that.identityDocument) &&
- Objects.equals(signature, that.signature);
+ Objects.equals(signature, that.signature) &&
+ Objects.equals(providerUniqueId, that.providerUniqueId) &&
+ Objects.equals(dnsSuffix, that.dnsSuffix) &&
+ Objects.equals(providerService, that.providerService) &&
+ Objects.equals(ztsEndpoint, that.ztsEndpoint) &&
+ Objects.equals(configServerHostname, that.configServerHostname) &&
+ Objects.equals(instanceHostname, that.instanceHostname) &&
+ Objects.equals(createdAt, that.createdAt) &&
+ Objects.equals(ipAddresses, that.ipAddresses) &&
+ Objects.equals(identityType, identityType);
}
@Override
public int hashCode() {
- return Objects.hash(rawIdentityDocument, identityDocument, signature, signingKeyVersion, documentVersion);
+ return Objects.hash(rawIdentityDocument, identityDocument, signature, signingKeyVersion, providerUniqueId, dnsSuffix, providerService, ztsEndpoint, documentVersion, configServerHostname, instanceHostname, createdAt, ipAddresses, identityType);
}
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java
index 3c521e992ad..3fdbb49b28e 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.athenz.identityprovider.api.bindings;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Objects;
@@ -26,14 +27,18 @@ public class VespaUniqueInstanceIdEntity {
public final String clusterId;
@JsonProperty("cluster-index")
public final int clusterIndex;
+ @JsonProperty("type")
+ public final String type;
+ @JsonCreator
public VespaUniqueInstanceIdEntity(@JsonProperty("tenant") String tenant,
@JsonProperty("application") String application,
@JsonProperty("environment") String environment,
@JsonProperty("region") String region,
@JsonProperty("instance") String instance,
@JsonProperty("cluster-id") String clusterId,
- @JsonProperty("cluster-index") int clusterIndex) {
+ @JsonProperty("cluster-index") int clusterIndex,
+ @JsonProperty("type") String type) {
this.tenant = tenant;
this.application = application;
this.environment = environment;
@@ -41,8 +46,21 @@ public class VespaUniqueInstanceIdEntity {
this.instance = instance;
this.clusterId = clusterId;
this.clusterIndex = clusterIndex;
+ this.type = type;
}
+ @Deprecated
+ public VespaUniqueInstanceIdEntity(String tenant,
+ String application,
+ String environment,
+ String region,
+ String instance,
+ String clusterId,
+ int clusterIndex) {
+ this(tenant, application, environment, region, instance, clusterId, clusterIndex, null);
+ }
+
+
@Override
public String toString() {
return "VespaUniqueInstanceIdEntity{" +
@@ -53,6 +71,7 @@ public class VespaUniqueInstanceIdEntity {
", instance='" + instance + '\'' +
", clusterId='" + clusterId + '\'' +
", clusterIndex=" + clusterIndex +
+ ", type='" + type + '\'' +
'}';
}
@@ -67,11 +86,12 @@ public class VespaUniqueInstanceIdEntity {
Objects.equals(environment, that.environment) &&
Objects.equals(region, that.region) &&
Objects.equals(instance, that.instance) &&
- Objects.equals(clusterId, that.clusterId);
+ Objects.equals(clusterId, that.clusterId) &&
+ Objects.equals(type, that.type);
}
@Override
public int hashCode() {
- return Objects.hash(tenant, application, environment, region, instance, clusterId, clusterIndex);
+ return Objects.hash(tenant, application, environment, region, instance, clusterId, clusterIndex, type);
}
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
index 96e93ca419d..e8ef2d9f97e 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.athenz.identityprovider.client;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.vespa.athenz.api.AthenzService;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
@@ -28,7 +29,7 @@ import static com.yahoo.vespa.athenz.tls.KeyStoreType.JKS;
*/
class AthenzCredentialsService {
- private static final ObjectMapper mapper = new ObjectMapper();
+ private static final ObjectMapper mapper = new ObjectMapper().registerModule(new JavaTimeModule());
private final IdentityConfig identityConfig;
private final IdentityDocumentClient identityDocumentClient;
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
index 90d1312c9f9..b9aba6e66b0 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
@@ -2,14 +2,12 @@
package com.yahoo.vespa.athenz.identityprovider.client;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.yahoo.vespa.athenz.api.AthenzService;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
-import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity;
-import com.yahoo.vespa.athenz.utils.AthenzIdentities;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
@@ -34,7 +32,7 @@ import java.util.function.Supplier;
public class DefaultIdentityDocumentClient implements IdentityDocumentClient {
private static final String IDENTITY_DOCUMENT_API = "/athenz/v1/provider/identity-document/";
- private static final ObjectMapper objectMapper = new ObjectMapper();
+ private static final ObjectMapper objectMapper = new ObjectMapper().registerModule(new JavaTimeModule());
private final Supplier<SSLContext> sslContextSupplier;
private final HostnameVerifier hostnameVerifier;
@@ -82,15 +80,7 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient {
String responseContent = EntityUtils.toString(response.getEntity());
if (HttpStatus.isSuccess(response.getStatusLine().getStatusCode())) {
SignedIdentityDocumentEntity entity = objectMapper.readValue(responseContent, SignedIdentityDocumentEntity.class);
- return new SignedIdentityDocument(
- EntityBindingsMapper.toIdentityDocument(entity.identityDocument),
- entity.signature,
- entity.signingKeyVersion,
- VespaUniqueInstanceId.fromDottedString(entity.providerUniqueId),
- entity.dnsSuffix,
- (AthenzService) AthenzIdentities.from(entity.providerService),
- entity.ztsEndpoint,
- entity.documentVersion);
+ return EntityBindingsMapper.toSignedIdentityDocument(entity);
} else {
throw new RuntimeException(
String.format(
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java
index 8c4e4c1262d..86b6c566987 100644
--- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java
@@ -2,6 +2,7 @@ package com.yahoo.vespa.athenz.identityprovider.api;
import org.junit.Test;
+import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*;
import static org.junit.Assert.*;
/**
@@ -12,6 +13,18 @@ public class VespaUniqueInstanceIdTest {
@Test
public void can_serialize_to_and_deserialize_from_string() {
VespaUniqueInstanceId id =
+ new VespaUniqueInstanceId(1, "cluster-id", "instance", "application", "tenant", "region", "environment", TENANT);
+ String stringRepresentation = id.asDottedString();
+ String expectedStringRepresentation = "1.cluster-id.instance.application.tenant.region.environment.tenant";
+ assertEquals(expectedStringRepresentation, stringRepresentation);
+ VespaUniqueInstanceId deserializedId = VespaUniqueInstanceId.fromDottedString(stringRepresentation);
+ assertEquals(id, deserializedId);
+ }
+
+ // TODO Remove support for legacy representation without type
+ @Test
+ public void supports_legacy_representation_without_type() {
+ VespaUniqueInstanceId id =
new VespaUniqueInstanceId(1, "cluster-id", "instance", "application", "tenant", "region", "environment");
String stringRepresentation = id.asDottedString();
String expectedStringRepresentation = "1.cluster-id.instance.application.tenant.region.environment";
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
index 2e9b29f5327..7ad465a7d80 100644
--- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
@@ -11,6 +11,7 @@ import com.yahoo.test.ManualClock;
import com.yahoo.vespa.athenz.api.AthenzService;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
import com.yahoo.vespa.athenz.tls.KeyStoreBuilder;
@@ -132,7 +133,7 @@ public class AthenzIdentityProviderImplTest {
}
private static String getIdentityDocument() throws JsonProcessingException {
- VespaUniqueInstanceId instanceId = new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", "us-north-1", "dev");
+ VespaUniqueInstanceId instanceId = new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", "us-north-1", "dev", IdentityType.TENANT);
SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument(
new IdentityDocument(instanceId, "localhost", "x.y.com", Instant.EPOCH, Collections.emptySet()),
"dummysignature",
@@ -141,7 +142,12 @@ public class AthenzIdentityProviderImplTest {
"dev-us-north-1.vespa.cloud",
new AthenzService("vespa.vespa.provider_dev_us-north-1"),
URI.create("https://zts:4443/zts/v1"),
- 1);
+ 1,
+ "localhost",
+ "x.y.com",
+ Instant.EPOCH,
+ Collections.emptySet(),
+ IdentityType.TENANT);
return new ObjectMapper().registerModule(new JavaTimeModule())
.writeValueAsString(EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument));
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java
index 671038c852a..84d3b320772 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java
@@ -11,10 +11,15 @@ import io.airlift.command.Command;
import io.airlift.command.HelpOption;
import io.airlift.command.Option;
import io.airlift.command.SingleCommand;
+import org.apache.http.Header;
+import org.apache.http.ParseException;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
+import org.apache.http.message.BasicLineParser;
import javax.inject.Inject;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.TimeUnit;
/**
@@ -53,6 +58,15 @@ public class CommandLineArguments {
return null;
}
+ for (String header : cmdArgs.headers) {
+ try {
+ cmdArgs.parsedHeaders.add(BasicLineParser.parseHeader(header, null));
+ } catch (ParseException e) {
+ System.err.printf("Invalid header: '%s' (%s)%n", header, e.getMessage());
+ return null;
+ }
+ }
+
return cmdArgs;
}
@@ -180,6 +194,12 @@ public class CommandLineArguments {
description = "Skip hostname verification when using TLS")
private boolean insecure = false;
+ @Option(name = {"--header"},
+ description = "Add http header to every request. Header must have the format '<Name>: <Value>'. Use this parameter multiple times for multiple headers")
+ private List<String> headers = new ArrayList<>();
+
+ private final List<Header> parsedHeaders = new ArrayList<>();
+
int getWhenVerboseEnabledPrintMessageForEveryXDocuments() {
return whenVerboseEnabledPrintMessageForEveryXDocuments;
}
@@ -192,6 +212,8 @@ public class CommandLineArguments {
SessionParams createSessionParams(boolean useJson) {
final int minThrottleValue = useDynamicThrottlingArg ? 10 : 0;
+ ConnectionParams.Builder connectionParamsBuilder = new ConnectionParams.Builder();
+ parsedHeaders.forEach(header -> connectionParamsBuilder.addHeader(header.getName(), header.getValue()));
SessionParams.Builder builder = new SessionParams.Builder()
.setFeedParams(
new FeedParams.Builder()
@@ -208,7 +230,7 @@ public class CommandLineArguments {
.build()
)
.setConnectionParams(
- new ConnectionParams.Builder()
+ connectionParamsBuilder
.setHostnameVerifier(insecure ? NoopHostnameVerifier.INSTANCE :
SSLConnectionSocketFactory.getDefaultHostnameVerifier())
.setNumPersistentConnectionsPerEndpoint(16)
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java
index e0d93a7fa18..84a69520a84 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java
@@ -7,7 +7,13 @@ import com.yahoo.vespa.http.client.config.SessionParams;
import org.junit.Test;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@@ -109,6 +115,7 @@ public class CommandLineArgumentsTest {
add("debugport", "7890");
args.add("--verbose");
args.add("--useTls");
+ add("header", "Header-Name: Header-Value");
CommandLineArguments arguments = CommandLineArguments.build(asArray());
SessionParams params = arguments.createSessionParams(true /* use json */);
assertThat(params.getClientQueueSize(), is(3456));
@@ -116,6 +123,7 @@ public class CommandLineArgumentsTest {
assertThat(params.getClusters().get(0).getEndpoints().get(0).getPort(), is(1234));
assertThat(params.getClusters().get(0).getEndpoints().get(0).isUseSsl(), is(true));
assertThat(params.getConnectionParams().getUseCompression(), is(true));
+ assertThat(params.getConnectionParams().getHeaders().size(), is(1));
assertThat(params.getFeedParams().getRoute(), is("routeValue"));
assertThat(params.getFeedParams().getDataFormat(), is(FeedParams.DataFormat.JSON_UTF8));
assertThat(params.getFeedParams().getLocalQueueTimeOut(), is(2345000L));
@@ -124,6 +132,31 @@ public class CommandLineArgumentsTest {
}
@Test
+ public void testAddingMultipleHttpHeaders() {
+ add("host", "hostValue");
+ String header1Name = "Header-Name-1";
+ String header1Value = "Header-Value";
+ add("header", header1Name + ": " + header1Value);
+ String header2Name = "Header-Name-2";
+ String header2Value = "Another-Header-Value";
+ add("header", header2Name + ": " + header2Value);
+
+ CommandLineArguments arguments = CommandLineArguments.build(asArray());
+ SessionParams params = arguments.createSessionParams(true /* use json */);
+
+ List<Map.Entry<String, String>> headers = new ArrayList<>(params.getConnectionParams().getHeaders());
+ headers.sort(Comparator.comparing(Map.Entry::getKey));
+
+ assertThat(headers.size(), is(2));
+ Map.Entry<String, String> actualHeader1 = headers.get(0);
+ assertThat(actualHeader1.getKey(), is(header1Name));
+ assertThat(actualHeader1.getValue(), is(header1Value));
+ Map.Entry<String, String> actualHeader2 = headers.get(1);
+ assertThat(actualHeader2.getKey(), is(header2Name));
+ assertThat(actualHeader2.getValue(), is(header2Value));
+ }
+
+ @Test
public void testMultiHost() {
add("file", "fileValue.json");
add("port", "1234");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 944755c9db2..3a66eef258d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -22,22 +22,37 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
+ public static DoubleBinaryOperator greater() { return new Greater(); }
+ public static DoubleBinaryOperator less() { return new Less(); }
public static DoubleBinaryOperator max() { return new Max(); }
public static DoubleBinaryOperator min() { return new Min(); }
+ public static DoubleBinaryOperator mean() { return new Mean(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
+ public static DoubleBinaryOperator pow() { return new Pow(); }
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
public static DoubleBinaryOperator subtract() { return new Subtract(); }
+ public static DoubleUnaryOperator abs() { return new Abs(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator asin() { return new Asin(); }
+ public static DoubleUnaryOperator atan() { return new Atan(); }
+ public static DoubleUnaryOperator ceil() { return new Ceil(); }
+ public static DoubleUnaryOperator cos() { return new Cos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
+ public static DoubleUnaryOperator log() { return new Log(); }
+ public static DoubleUnaryOperator neg() { return new Neg(); }
+ public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleUnaryOperator tan() { return new Tan(); }
+ public static DoubleUnaryOperator tanh() { return new Tanh(); }
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
@@ -59,6 +74,20 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
+ public static class Greater implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a > b)"; }
+ }
+
+ public static class Less implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a < b)"; }
+ }
+
public static class Max implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@@ -73,6 +102,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(min(a, b))"; }
}
+ public static class Mean implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return (left + right) / 2; }
+ @Override
+ public String toString() { return "f(a,b)((a + b) / 2)"; }
+ }
+
public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@@ -80,6 +116,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a * b)"; }
}
+ public static class Pow implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
+ @Override
+ public String toString() { return "f(a,b)(pow(a, b))"; }
+ }
+
public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
@@ -104,6 +147,13 @@ public class ScalarFunctions {
// Unary operators ------------------------------------------------------------------------------
+ public static class Abs implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.abs(operand); }
+ @Override
+ public String toString() { return "f(a)(fabs(a))"; }
+ }
+
public static class Acos implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.acos(operand); }
@@ -111,6 +161,34 @@ public class ScalarFunctions {
public String toString() { return "f(a)(acos(a))"; }
}
+ public static class Asin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.asin(operand); }
+ @Override
+ public String toString() { return "f(a)(asin(a))"; }
+ }
+
+ public static class Atan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.atan(operand); }
+ @Override
+ public String toString() { return "f(a)(atan(a))"; }
+ }
+
+ public static class Ceil implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.ceil(operand); }
+ @Override
+ public String toString() { return "f(a)(ceil(a))"; }
+ }
+
+ public static class Cos implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.cos(operand); }
+ @Override
+ public String toString() { return "f(a)(cos(a))"; }
+ }
+
public static class Elu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
@@ -132,6 +210,26 @@ public class ScalarFunctions {
public String toString() { return "f(a)(floor(a))"; }
}
+ public static class Log implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.log(operand); }
+ @Override
+ public String toString() { return "f(a)(log(a))"; }
+ }
+
+ public static class Neg implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return -operand; }
+ @Override
+ public String toString() { return "f(a)(-a)"; }
+ }
+
+ public static class Reciprocal implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return 1.0 / operand; }
+ @Override
+ public String toString() { return "f(a)(1 / a)"; }
+ }
public static class Relu implements DoubleUnaryOperator {
@Override
@@ -150,6 +248,13 @@ public class ScalarFunctions {
public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); }
}
+ public static class Sin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.sin(operand); }
+ @Override
+ public String toString() { return "f(a)(sin(a))"; }
+ }
+
public static class Rsqrt implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@@ -172,15 +277,29 @@ public class ScalarFunctions {
}
public static class Square implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return operand * operand; }
-
@Override
public String toString() { return "f(a)(a * a)"; }
+ }
+ public static class Tan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tan(operand); }
+ @Override
+ public String toString() { return "f(a)(tan(a))"; }
}
+ public static class Tanh implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tanh(operand); }
+ @Override
+ public String toString() { return "f(a)(tanh(a))"; }
+ }
+
+
+
+
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function<List<Long>, Double> {