summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt1
-rw-r--r--athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/SecretStoreKeyProvider.java)19
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java13
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java20
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java33
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java25
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java22
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java28
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java25
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java9
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java44
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java43
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java25
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java20
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java20
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java110
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java167
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java80
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java26
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java32
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java82
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java7
-rw-r--r--config-model/src/test/derived/rankexpression/rank-profiles.cfg4
-rw-r--r--config-model/src/test/derived/rankexpression/rankexpression.sd38
-rw-r--r--config-model/src/test/derived/rankexpression/summary.cfg18
-rw-r--r--config-model/src/test/derived/rankexpression/summarymap.cfg26
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg2
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd2
-rw-r--r--config-model/src/test/examples/rankpropvars.sd8
-rw-r--r--config-model/src/test/examples/simple.sd2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java20
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java239
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java133
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java1
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java9
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java14
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java8
-rw-r--r--docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java120
-rw-r--r--docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java94
-rw-r--r--docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java228
-rw-r--r--docker-api/src/main/resources/configdefinitions/docker.def3
-rw-r--r--docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java61
-rw-r--r--docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java197
-rw-r--r--docker-api/src/test/resources/simple-ipv6-server/Dockerfile13
-rw-r--r--docker-api/src/test/resources/simple-ipv6-server/README10
-rw-r--r--docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py11
-rw-r--r--docker-api/src/test/resources/simple-ipv6-server/src/server.py43
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/RealConfigServerClients.java (renamed from node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/ConfigServerClientsImpl.java)14
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java (renamed from node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImpl.java)20
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java2
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java4
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java2
-rw-r--r--node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java (renamed from node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImplTest.java)16
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp58
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h3
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp6
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h2
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp4
-rw-r--r--searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java31
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java121
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java38
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java132
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java715
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java155
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java396
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java108
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java237
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java224
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java93
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java79
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java38
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java36
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java57
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java135
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java89
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java55
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java84
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java48
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java136
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java40
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java18
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java10
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java49
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java10
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java4
-rw-r--r--storage/src/tests/common/testnodestateupdater.cpp6
-rw-r--r--storage/src/tests/common/testnodestateupdater.h4
-rw-r--r--storage/src/tests/persistence/filestorage/filestormanagertest.cpp2
-rw-r--r--storage/src/tests/storageserver/statemanagertest.cpp2
-rw-r--r--storage/src/vespa/storage/bucketdb/bucketmanager.cpp2
-rw-r--r--storage/src/vespa/storage/common/CMakeLists.txt1
-rw-r--r--storage/src/vespa/storage/common/nodestateupdater.h4
-rw-r--r--storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp2
-rw-r--r--storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp2
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp2
-rw-r--r--storage/src/vespa/storage/storageserver/bouncer.cpp2
-rw-r--r--storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp2
-rw-r--r--storage/src/vespa/storage/storageserver/fnetlistener.cpp2
-rw-r--r--storage/src/vespa/storage/storageserver/mergethrottler.cpp2
-rw-r--r--storage/src/vespa/storage/storageserver/statemanager.cpp4
-rw-r--r--storage/src/vespa/storage/storageserver/statemanager.h3
-rw-r--r--storageapi/src/vespa/storageapi/message/state.cpp11
-rw-r--r--storageapi/src/vespa/storageapi/message/state.h16
-rw-r--r--vdslib/src/vespa/vdslib/state/CMakeLists.txt1
-rw-r--r--vdslib/src/vespa/vdslib/state/cluster_state_bundle.cpp (renamed from storage/src/vespa/storage/common/cluster_state_bundle.cpp)4
-rw-r--r--vdslib/src/vespa/vdslib/state/cluster_state_bundle.h (renamed from storage/src/vespa/storage/common/cluster_state_bundle.h)7
-rw-r--r--vespa-athenz/CMakeLists.txt2
-rw-r--r--vespa-athenz/pom.xml32
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentials.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentialsService.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImpl.java)5
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzService.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzService.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtils.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtils.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/IdentityDocumentService.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/IdentityDocumentService.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceIdentity.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceIdentity.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRefreshInformation.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRefreshInformation.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRegisterInformation.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRegisterInformation.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/SignedIdentityDocument.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/SignedIdentityDocument.java)2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/package-info.java (renamed from athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/package-info.java)2
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java (renamed from athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImplTest.java)24
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtilsTest.java (renamed from athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtilsTest.java)2
-rw-r--r--vespajlib/src/main/java/com/yahoo/lang/MutableLong.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java86
192 files changed, 4072 insertions, 2955 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1fe91c233ab..22ebd6abb4c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -99,6 +99,7 @@ add_subdirectory(streamingvisitors)
add_subdirectory(vbench)
add_subdirectory(vdslib)
add_subdirectory(vdstestlib)
+add_subdirectory(vespa-athenz)
add_subdirectory(vespa-http-client)
add_subdirectory(vespa_jersey2)
add_subdirectory(vespabase)
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/SecretStoreKeyProvider.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java
index ac8c0eabf31..2f2cd5a8495 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/SecretStoreKeyProvider.java
+++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java
@@ -1,10 +1,10 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl;
import com.google.inject.Inject;
import com.yahoo.athenz.auth.util.Crypto;
import com.yahoo.config.provision.Zone;
-import com.yahoo.jdisc.http.SecretStore;
+import com.yahoo.container.jdisc.Ckms;
import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider;
import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig;
@@ -18,19 +18,20 @@ import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.g
/**
* @author mortent
+ * @author bjorncs
*/
@SuppressWarnings("unused") // Injected component
-public class SecretStoreKeyProvider implements KeyProvider {
+public class CkmsKeyProvider implements KeyProvider {
- private final SecretStore secretStore;
+ private final Ckms ckms;
private final String secretName;
private final Map<Integer, KeyPair> secrets;
@Inject
- public SecretStoreKeyProvider(SecretStore secretStore,
- Zone zone,
- AthenzProviderServiceConfig config) {
- this.secretStore = secretStore;
+ public CkmsKeyProvider(Ckms ckms,
+ Zone zone,
+ AthenzProviderServiceConfig config) {
+ this.ckms = ckms;
this.secretName = getZoneConfig(config, zone).secretName();
this.secrets = new HashMap<>();
}
@@ -59,7 +60,7 @@ public class SecretStoreKeyProvider implements KeyProvider {
// TODO: Consider moving to cryptoutils
private KeyPair readKeyPair(int version) {
- PrivateKey privateKey = Crypto.loadPrivateKey(secretStore.getSecret(secretName, version));
+ PrivateKey privateKey = Crypto.loadPrivateKey(ckms.getSecret(secretName, version));
PublicKey publicKey = Crypto.extractPublicKey(privateKey);
return new KeyPair(publicKey, privateKey);
}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
index 0ff59c26c13..74ddd941afb 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
@@ -165,19 +165,6 @@ public class ContentCluster {
}
}
- public StorageNodeStats getStorageNodeStats(int storageNodeIndex) {
- LatencyStats aggregatePutLatencyStats = new LatencyStats();
- StorageNodeStats aggregateStats = new StorageNodeStats(aggregatePutLatencyStats);
- for (DistributorNodeInfo distributor : clusterInfo.getDistributorNodeInfo()) {
- StorageNodeStats statsFromDistributor = distributor.getStorageNodeStatsOrNull(storageNodeIndex);
- if (statsFromDistributor != null) {
- aggregateStats.add(statsFromDistributor);
- }
- }
-
- return aggregateStats;
- }
-
/**
* Checks if a node can be upgraded
*
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java
index 575b965c0e5..a21fbd22213 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java
@@ -15,28 +15,8 @@ import com.yahoo.vespa.clustercontroller.core.hostinfo.StorageNodeStatsBridge;
*/
public class DistributorNodeInfo extends NodeInfo {
- private StorageNodeStatsContainer storageNodeStatsContainer = null;
-
public DistributorNodeInfo(ContentCluster cluster, int index, String rpcAddress, Distribution distribution) {
super(cluster, new Node(NodeType.DISTRIBUTOR, index), false, rpcAddress, distribution);
}
- @Override
- public void setHostInfo(HostInfo hostInfo) {
- // This affects getHostInfo(), and makes the host info available through NodeInfo.
- super.setHostInfo(hostInfo);
- storageNodeStatsContainer = StorageNodeStatsBridge.traverseHostInfo(hostInfo);
- }
-
- /**
- * @return Stats this distributor has about a storage node, or null if unknown.
- */
- public StorageNodeStats getStorageNodeStatsOrNull(int storageNodeIndex) {
- if (storageNodeStatsContainer == null) {
- return null;
- }
-
- return storageNodeStatsContainer.get(storageNodeIndex);
- }
-
}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java
deleted file mode 100644
index 581cc244a20..00000000000
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-/**
- * LatencyStats handles adding latencies and counts.
- *
- * @author hakonhall
- */
-public class LatencyStats {
-
- private long latencyMsSum;
- private long count;
-
- public LatencyStats() { this(0, 0); }
-
- /**
- * @param latencyMsSum The sum of the latencies of all RPCs (or whatever) in milliseconds.
- * @param count The number of RPC calls (or whatever).
- */
- public LatencyStats(long latencyMsSum, long count) {
- this.latencyMsSum = latencyMsSum;
- this.count = count;
- }
-
- void add(LatencyStats latencyToAdd) {
- latencyMsSum += latencyToAdd.latencyMsSum;
- count += latencyToAdd.count;
- }
-
- public long getLatencyMsSum() { return latencyMsSum; }
- public long getCount() { return count; }
-
-}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java
deleted file mode 100644
index 8df5820bc49..00000000000
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-import java.util.Map;
-
-/**
- * Contains stats for a set of storage nodes. This is used to store the stats returned
- * by Distributors from their getnodestate RPCs. The stats for a single storage node
- * is represented by the StorageNodeStats class.
- *
- * @author hakonhall
- */
-public class StatsForStorageNodes {
-
- final private Map<Integer, StorageNodeStats> storageNodesByIndex;
-
- StatsForStorageNodes(Map<Integer, StorageNodeStats> storageNodesByIndex) {
- this.storageNodesByIndex = storageNodesByIndex;
- }
-
- StorageNodeStats getStatsForStorageNode(int nodeIndex) {
- return storageNodesByIndex.get(nodeIndex);
- }
-
-}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java
deleted file mode 100644
index d0afc1fa4b7..00000000000
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-/**
- * Contains stats related to a single storage node.
- *
- * @author hakonhall
- */
-public class StorageNodeStats {
-
- final private LatencyStats distributorPutLatency;
-
- /**
- * @param distributorPutLatency the "put" latency from the point of view of the distributor.
- */
- public StorageNodeStats(LatencyStats distributorPutLatency) { this.distributorPutLatency = distributorPutLatency; }
- public LatencyStats getDistributorPutLatency() { return distributorPutLatency; }
- public void add(StorageNodeStats statsToAdd) {
- distributorPutLatency.add(statsToAdd.distributorPutLatency);
- }
-
-}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java
deleted file mode 100644
index 1fb24e72218..00000000000
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * Contains stats for a set of storage nodes. This is used to store the stats returned
- * by Distributors from their getnodestate RPCs. The stats for a single storage node
- * is represented by the StorageNodeStats class.
- *
- * @author hakonhall
- */
-public class StorageNodeStatsContainer {
-
- final private Map<Integer, StorageNodeStats> storageNodesByIndex = new HashMap<>();
-
- public void put(int nodeIndex, StorageNodeStats nodeStats) {
- storageNodesByIndex.put(nodeIndex, nodeStats);
- }
-
- public StorageNodeStats get(int nodeIndex) {
- return storageNodesByIndex.get(nodeIndex);
- }
-
- public int size() { return storageNodesByIndex.size(); }
-
-}
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java
index 55b7e4bb8c1..30ef0c69fe3 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java
@@ -16,31 +16,6 @@ public class StorageNodeStatsBridge {
private StorageNodeStatsBridge() { }
- public static StorageNodeStatsContainer traverseHostInfo(HostInfo hostInfo) {
- StorageNodeStatsContainer container = new StorageNodeStatsContainer();
- List<StorageNode> storageNodes = hostInfo.getDistributor().getStorageNodes();
- for (StorageNode storageNode : storageNodes) {
- Integer storageNodeIndex = storageNode.getIndex();
- if (storageNodeIndex == null) {
- continue;
- }
- StorageNode.OpsLatency opsLatency = storageNode.getOpsLatenciesOrNull();
- if (opsLatency == null) {
- continue;
- }
- StorageNode.Put putLatency = opsLatency.getPut();
- Long putLatencyMsSum = putLatency.getLatencyMsSum();
- Long putLatencyCount = putLatency.getCount();
- if (putLatencyMsSum == null || putLatencyCount == null) {
- continue;
- }
- LatencyStats putLatencyStats = new LatencyStats(putLatencyMsSum, putLatencyCount);
- StorageNodeStats nodeStats = new StorageNodeStats(putLatencyStats);
- container.put(storageNodeIndex, nodeStats);
- }
- return container;
- }
-
public static ContentClusterStats generate(Distributor distributor) {
Map<Integer, ContentNodeStats> mapToNodeStats = new HashMap<>();
for (StorageNode storageNode : distributor.getStorageNodes()) {
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
index 9c7143aed4a..669042c2fd8 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
@@ -1,10 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.clustercontroller.core.restapiv2.requests;
-import com.yahoo.vespa.clustercontroller.core.LatencyStats;
import com.yahoo.vespa.clustercontroller.core.NodeInfo;
import com.yahoo.vespa.clustercontroller.core.RemoteClusterControllerTask;
-import com.yahoo.vespa.clustercontroller.core.StorageNodeStats;
import com.yahoo.vespa.clustercontroller.core.restapiv2.Id;
import com.yahoo.vespa.clustercontroller.core.restapiv2.Request;
import com.yahoo.vespa.clustercontroller.core.restapiv2.Response;
@@ -41,13 +39,6 @@ public class NodeStateRequest extends Request<Response.NodeResponse> {
result.addState("unit", new Response.UnitStateImpl(info.getReportedState()));
result.addState("user", new Response.UnitStateImpl(info.getWantedState()));
- if (info.isStorage() && verboseReports.contains(VerboseReport.STATISTICS)) {
- StorageNodeStats storageStats = context.cluster.getStorageNodeStats(info.getNodeIndex());
- LatencyStats latencyStats = storageStats.getDistributorPutLatency();
- result.addMetric("distributor-put-latency-ms-sum", latencyStats.getLatencyMsSum());
- result.addMetric("distributor-put-latency-count", latencyStats.getCount());
- }
-
for (int i=0; i<info.getReportedState().getDiskCount(); ++i) {
Id.Partition partitionId = new Id.Partition(id, i);
if (recursive > 0) {
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java
deleted file mode 100644
index 2e88c147095..00000000000
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-import org.junit.Test;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-
-/**
- * @author hakonhall
- */
-public class StatsForStorageNodeTest {
- @Test
- public void testStatsForStorage() {
- Map<Integer, StorageNodeStats> statsMap = new HashMap<>();
-
- LatencyStats putLatencyForA = new LatencyStats(1, 2);
- StorageNodeStats nodeStatsForA = new StorageNodeStats(putLatencyForA);
- statsMap.put(5, nodeStatsForA);
-
- LatencyStats putLatencyForB = new LatencyStats(3, 4);
- StorageNodeStats nodeStatsForB = new StorageNodeStats(putLatencyForB);
- statsMap.put(6, nodeStatsForB);
-
- StatsForStorageNodes stats = new StatsForStorageNodes(statsMap);
-
- StorageNodeStats nodeStats = stats.getStatsForStorageNode(5);
- assertNotNull(nodeStats);
- assertEquals(1, nodeStatsForA.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(2, nodeStatsForA.getDistributorPutLatency().getCount());
-
- nodeStats = stats.getStatsForStorageNode(6);
- assertNotNull(nodeStats);
- assertEquals(3, nodeStatsForB.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(4, nodeStatsForB.getDistributorPutLatency().getCount());
-
- nodeStats = stats.getStatsForStorageNode(7);
- assertNull(nodeStats);
- }
-}
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java
deleted file mode 100644
index 5107792dbff..00000000000
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-import org.junit.Test;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-
-/**
- * @author hakonhall
- */
-public class StorageNodeStatsContainerTest {
- @Test
- public void testStatsForStorage() {
- StorageNodeStatsContainer statsContainer = new StorageNodeStatsContainer();
- Map<Integer, StorageNodeStats> statsMap = new HashMap<>();
-
- LatencyStats putLatencyForA = new LatencyStats(1, 2);
- StorageNodeStats nodeStatsForA = new StorageNodeStats(putLatencyForA);
- statsContainer.put(5, nodeStatsForA);
-
- LatencyStats putLatencyForB = new LatencyStats(3, 4);
- StorageNodeStats nodeStatsForB = new StorageNodeStats(putLatencyForB);
- statsContainer.put(6, nodeStatsForB);
-
- StorageNodeStats nodeStats = statsContainer.get(5);
- assertNotNull(nodeStats);
- assertEquals(1, nodeStatsForA.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(2, nodeStatsForA.getDistributorPutLatency().getCount());
-
- nodeStats = statsContainer.get(6);
- assertNotNull(nodeStats);
- assertEquals(3, nodeStatsForB.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(4, nodeStatsForB.getDistributorPutLatency().getCount());
-
- nodeStats = statsContainer.get(7);
- assertNull(nodeStats);
- }
-}
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java
deleted file mode 100644
index 4defb015e76..00000000000
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.clustercontroller.core;
-
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author hakonhall
- */
-public class StorageNodeStatsTest {
- @Test
- public void testStorageNodeStats() {
- LatencyStats putLatency = new LatencyStats(1, 2);
- StorageNodeStats stats = new StorageNodeStats(putLatency);
- assertEquals(1, stats.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(2, stats.getDistributorPutLatency().getCount());
-
- LatencyStats putLatencyToAdd = new LatencyStats(3, 4);
- StorageNodeStats statsToAdd = new StorageNodeStats(putLatencyToAdd);
- stats.add(statsToAdd);
- assertEquals(1 + 3, stats.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(2 + 4, stats.getDistributorPutLatency().getCount());
- }
-}
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java
index 5319d741503..51e73b333c5 100644
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java
+++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java
@@ -3,8 +3,6 @@ package com.yahoo.vespa.clustercontroller.core.hostinfo;
import com.yahoo.vespa.clustercontroller.core.ContentNodeStats;
import com.yahoo.vespa.clustercontroller.core.ContentClusterStats;
-import com.yahoo.vespa.clustercontroller.core.StorageNodeStats;
-import com.yahoo.vespa.clustercontroller.core.StorageNodeStatsContainer;
import org.junit.Test;
import java.io.IOException;
@@ -31,24 +29,6 @@ public class StorageNodeStatsBridgeTest {
}
@Test
- public void testStorageNodeStatsContainer() throws IOException {
- String data = getJsonString();
- HostInfo hostInfo = HostInfo.createHostInfo(data);
- StorageNodeStatsContainer container = StorageNodeStatsBridge.traverseHostInfo(hostInfo);
- assertEquals(2, container.size());
-
- StorageNodeStats node0 = container.get(0);
- assertNotNull(node0);
- assertEquals(15, node0.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(16, node0.getDistributorPutLatency().getCount());
-
- StorageNodeStats node1 = container.get(1);
- assertNotNull(node1);
- assertEquals(17, node1.getDistributorPutLatency().getLatencyMsSum());
- assertEquals(18, node1.getDistributorPutLatency().getCount());
- }
-
- @Test
public void testContentNodeStats() throws IOException {
String data = getJsonString();
HostInfo hostInfo = HostInfo.createHostInfo(data);
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java
index 1421e901048..de28867520b 100644
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java
+++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java
@@ -60,14 +60,6 @@ public class NodeTest extends StateRestApiTest {
" \"reason\": \"\"\n" +
" }\n" +
" },\n" +
- " \"metrics\": {\n" +
- // Why 24 and 28? There are 4 distributor nodes seen in slobrok (see StateRestApiTest).
- // Each gets a host info with distributor-put-latency-ms-sum 6 and
- // distributor-put-latency-count 7 (see StateRestApiTest.getHostInfo()).
- // Therefore, in aggregate, 4*6 is 24, and 4*7 is 28.
- " \"distributor-put-latency-ms-sum\": 24,\n" +
- " \"distributor-put-latency-count\": 28\n" +
- " },\n" +
" \"partition\": {\n" +
" \"0\": {\"link\": \"\\/cluster\\/v2\\/music\\/storage\\/1\\/0\"},\n" +
" \"1\": {\"link\": \"\\/cluster\\/v2\\/music\\/storage\\/1\\/1\"}\n" +
@@ -97,14 +89,6 @@ public class NodeTest extends StateRestApiTest {
" \"reason\": \"\"\n" +
" }\n" +
" },\n" +
- " \"metrics\": {\n" +
- // Why 24 and 28? There are 4 distributor nodes seen in slobrok (see StateRestApiTest).
- // Each gets a host info with distributor-put-latency-ms-sum 6 and
- // distributor-put-latency-count 7 (see StateRestApiTest.getHostInfo()).
- // Therefore, in aggregate, 4*6 is 24, and 4*7 is 28.
- " \"distributor-put-latency-ms-sum\": 24,\n" +
- " \"distributor-put-latency-count\": 28\n" +
- " },\n" +
" \"partition\": {\n" +
" \"0\": {\n" +
" \"state\": {\"generated\": {\n" +
@@ -158,10 +142,6 @@ public class NodeTest extends StateRestApiTest {
" \"state\": \"up\",\n" +
" \"reason\": \"\"\n" +
" }\n" +
- " },\n" +
- " \"metrics\": {\n" +
- " \"distributor-put-latency-ms-sum\": 0,\n" +
- " \"distributor-put-latency-count\": 0\n" +
" }\n" +
"}";
assertEquals(expected, jsonWriter.createJson(response).toString(2));
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index d6b916680d8..bd94f67e4a7 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -323,7 +323,7 @@ public class DeployState implements ConfigDefinitionStore {
closeIgnoreException(reader.getReader());
}
}
- builder.build(logger, queryProfiles);
+ builder.build(logger);
return SearchDocumentModel.fromBuilderAndNames(builder, names);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
index dd03cb8b2a7..dc59d9cb3e5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
@@ -5,11 +5,10 @@
*/
package com.yahoo.searchdefinition;
-import java.util.Arrays;
-import java.util.List;
+import com.yahoo.searchlib.rankingexpression.Reference;
+
import java.util.Optional;
import java.util.regex.Pattern;
-import java.util.stream.Collectors;
/**
* Utility methods for query, document and constant rank feature names
@@ -20,85 +19,16 @@ public class FeatureNames {
private static final Pattern identifierRegexp = Pattern.compile("[A-Za-z0-9_][A-Za-z0-9_-]*");
- /**
- * <p>Returns the given query, document or constant feature in canonical form.
- * A feature name consists of a feature type name (query, attribute or constant),
- * followed by one argument enclosed in quotes.
- * The argument may be an identifier or any string single or double quoted.</p>
- *
- * <p>Argument string values may not contain comma, single quote nor double quote characters.</p>
- *
- * <p><i>The canonical form use no quotes for arguments which are identifiers, and double quotes otherwise.</i></p>
- *
- * <p>Note that the above definition is not true for features in general, which accept any ranking expression
- * as argument.</p>
- *
- * @throws IllegalArgumentException if the feature name is not valid
- */
- // Note that this implementation is more general than what is described above:
- // It accepts any number of arguments and an optional output
- public static String canonicalize(String feature) {
- return canonicalizeIfValid(feature).orElseThrow(() ->
- new IllegalArgumentException("A feature name must be on the form query(name), attribute(name) or " +
- "constant(name), but was '" + feature + "'"
- ));
- }
-
- /**
- * Canonicalizes the given argument as in canonicalize, but returns empty instead of throwing an exception if
- * the argument is not a valid feature
- */
- public static Optional<String> canonicalizeIfValid(String feature) {
- int startParenthesis = feature.indexOf('(');
- if (startParenthesis < 0)
- return Optional.empty();
- int endParenthesis = feature.lastIndexOf(')');
- String featureType = feature.substring(0, startParenthesis);
- if ( ! ( featureType.equals("query") || featureType.equals("attribute") || featureType.equals("constant")))
- return Optional.empty();
- if (startParenthesis < 1) return Optional.of(feature); // No arguments
- if (endParenthesis < startParenthesis)
- return Optional.empty();
- String argumentString = feature.substring(startParenthesis + 1, endParenthesis);
- List<String> canonicalizedArguments =
- Arrays.stream(argumentString.split(","))
- .map(FeatureNames::canonicalizeArgument)
- .collect(Collectors.toList());
- return Optional.of(featureType + "(" +
- canonicalizedArguments.stream().collect(Collectors.joining(",")) +
- feature.substring(endParenthesis));
- }
-
- /** Canomicalizes a single argument */
- private static String canonicalizeArgument(String argument) {
- if (argument.startsWith("'")) {
- if ( ! argument.endsWith("'"))
- throw new IllegalArgumentException("Feature arguments starting by a single quote " +
- "must end by a single quote, but was \"" + argument + "\"");
- argument = argument.substring(1, argument.length() - 1);
- }
- if (argument.startsWith("\"")) {
- if ( ! argument.endsWith("\""))
- throw new IllegalArgumentException("Feature arguments starting by a double quote " +
- "must end by a double quote, but was '" + argument + "'");
- argument = argument.substring(1, argument.length() - 1);
- }
- if (identifierRegexp.matcher(argument).matches())
- return argument;
- else
- return "\"" + argument + "\"";
- }
-
- public static String asConstantFeature(String constantName) {
- return canonicalize("constant(\"" + constantName + "\")");
+ public static Reference asConstantFeature(String constantName) {
+ return Reference.simple("constant", quoteIfNecessary(constantName));
}
- public static String asAttributeFeature(String attributeName) {
- return canonicalize("attribute(\"" + attributeName + "\")");
+ public static Reference asAttributeFeature(String attributeName) {
+ return Reference.simple("attribute", quoteIfNecessary(attributeName));
}
- public static String asQueryFeature(String propertyName) {
- return canonicalize("query(\"" + propertyName + "\")");
+ public static Reference asQueryFeature(String propertyName) {
+ return Reference.simple("query", quoteIfNecessary(propertyName));
}
/**
@@ -106,15 +36,21 @@ public class FeatureNames {
* or empty if it is not a valid query, attribute or constant feature name
*/
public static Optional<String> argumentOf(String feature) {
- return canonicalizeIfValid(feature).map(f -> {
- int startParenthesis = f.indexOf("(");
- int endParenthesis = f.indexOf(")");
- String possiblyQuotedArgument = f.substring(startParenthesis + 1, endParenthesis);
- if (possiblyQuotedArgument.startsWith("\""))
- return possiblyQuotedArgument.substring(1, possiblyQuotedArgument.length() - 1);
- else
- return possiblyQuotedArgument;
- });
+ Optional<Reference> reference = Reference.simple(feature);
+ if ( ! reference.isPresent()) return Optional.empty();
+ if ( ! ( reference.get().name().equals("attribute") ||
+ reference.get().name().equals("constant") ||
+ reference.get().name().equals("query")))
+ return Optional.empty();
+
+ return Optional.of(reference.get().arguments().expressions().get(0).toString());
+ }
+
+ private static String quoteIfNecessary(String s) {
+ if (identifierRegexp.matcher(s).matches())
+ return s;
+ else
+ return "\"" + s + "\"";
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
new file mode 100644
index 00000000000..cf6d90db7fa
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -0,0 +1,167 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
+import com.yahoo.searchlib.rankingexpression.rule.NameNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * A context which only contains type information.
+ * This returns empty tensor types (double) for unknown features which are not
+ * query, attribute or constant features, as we do not have information about which such
+ * features exist (but we know those that exist are doubles).
+ *
+ * @author bratseth
+ */
+public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> {
+
+ private final Map<Reference, TensorType> featureTypes = new HashMap<>();
+
+ public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) {
+ super(functions);
+ }
+
+ public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
+ Map<String, String> bindings,
+ Map<Reference, TensorType> featureTypes) {
+ super(functions, bindings);
+ this.featureTypes.putAll(featureTypes);
+ }
+
+ public void setType(Reference reference, TensorType type) {
+ featureTypes.put(reference, type);
+ }
+
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ Optional<String> binding = boundIdentifier(reference);
+ if (binding.isPresent()) {
+ try {
+ // This is not pretty, but changing to bind expressions rather
+ // than their string values requires deeper changes
+ return new RankingExpression(binding.get()).type(this);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ if (isSimpleFeature(reference)) {
+ // The argument may be a local identifier bound to the actual value
+ String argument = simpleArgument(reference.arguments()).get();
+ reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument));
+ return featureTypes.getOrDefault(reference, defaultTypeOf(reference));
+ }
+
+ Optional<ExpressionFunction> function = functionInvocation(reference);
+ if (function.isPresent()) {
+ return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
+ }
+
+ // We do not know what this is - since we do not have complete knowledge abut the match features
+ // in Java we must assume this is a match feature and return the double type - which is the type of all
+ // all match features
+ return TensorType.empty;
+ }
+
+ /**
+ * Returns the default type for this simple feature, or nullif it does not have a default
+ */
+ public TensorType defaultTypeOf(Reference reference) {
+ if ( ! isSimpleFeature(reference))
+ throw new IllegalArgumentException("This can only be called for simple references, not " + reference);
+ if (reference.name().equals("query")) // we do not require all query features to be declared, only non-doubles
+ return TensorType.empty;
+ return null;
+ }
+
+ /**
+ * Returns the binding if this reference is a simple identifier which is bound in this context.
+ * Returns empty otherwise.
+ */
+ private Optional<String> boundIdentifier(Reference reference) {
+ if ( ! reference.arguments().isEmpty()) return Optional.empty();
+ if ( reference.output() != null) return Optional.empty();
+ return Optional.ofNullable(bindings.get(reference.name()));
+ }
+
+ /**
+ * Return whether the reference (discarding the output) is a simple feature
+ * ("attribute(name)", "constant(name)" or "query(name)").
+ * We disregard the output because all outputs under a simple feature have the same type.
+ */
+ private boolean isSimpleFeature(Reference reference) {
+ Optional<String> argument = simpleArgument(reference.arguments());
+ if ( ! argument.isPresent()) return false;
+ return reference.name().equals("attribute") ||
+ reference.name().equals("constant") ||
+ reference.name().equals("query");
+ }
+
+ /**
+ * If these arguments contains one simple argument string, it is returned.
+ * Otherwise null is returned.
+ */
+ private Optional<String> simpleArgument(Arguments arguments) {
+ if (arguments.expressions().size() != 1) return Optional.empty();
+ ExpressionNode argument = arguments.expressions().get(0);
+
+ if ( ! (argument instanceof ReferenceNode)) return Optional.empty();
+ ReferenceNode refArgument = (ReferenceNode)argument;
+
+ if ( ! refArgument.reference().isIdentifier()) return Optional.empty();
+
+ return Optional.of(refArgument.getName());
+ }
+
+ private Optional<ExpressionFunction> functionInvocation(Reference reference) {
+ if (reference.output() != null) return Optional.empty();
+ ExpressionFunction function = functions().get(reference.name());
+ if (function == null) return Optional.empty();
+ if (function.arguments().size() != reference.arguments().size()) return Optional.empty();
+ return Optional.of(function);
+ }
+
+ /** Binds the given list of formal arguments to their actual values */
+ private Map<String, String> bind(List<String> formalArguments,
+ Arguments invocationArguments) {
+ Map<String, String> bindings = new HashMap<>(formalArguments.size());
+ for (int i = 0; i < formalArguments.size(); i++) {
+ String identifier = invocationArguments.expressions().get(i).toString();
+ identifier = super.bindings.getOrDefault(identifier, identifier);
+ bindings.put(formalArguments.get(i), identifier);
+ }
+ return bindings;
+ }
+
+ public Map<Reference, TensorType> featureTypes() {
+ return Collections.unmodifiableMap(featureTypes);
+ }
+
+ @Override
+ public MapEvaluationTypeContext withBindings(Map<String, String> bindings) {
+ if (bindings.isEmpty() && this.bindings.isEmpty()) return this;
+ return new MapEvaluationTypeContext(functions(), bindings, featureTypes);
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index bcbc7cc99e2..064897de8dc 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -2,24 +2,18 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.deploy.DeployState;
-import com.yahoo.io.reader.NamedReader;
-import com.yahoo.processing.request.CompoundName;
-import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.search.query.profile.config.QueryProfileXMLReader;
import com.yahoo.search.query.profile.types.FieldDescription;
import com.yahoo.search.query.profile.types.QueryProfileType;
-import com.yahoo.search.query.profile.types.TensorFieldType;
import com.yahoo.search.query.ranking.Diversity;
-import com.yahoo.searchdefinition.document.SDField;
+import com.yahoo.searchdefinition.document.ImmutableSDField;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.FeatureList;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
@@ -39,7 +33,10 @@ import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* Represents a rank profile - a named set of ranking settings
@@ -363,14 +360,14 @@ public class RankProfile implements Serializable, Cloneable {
/** Returns a read-only view of the summary features to use in this profile. This is never null */
public Set<ReferenceNode> getSummaryFeatures() {
- if (summaryFeatures!=null) return Collections.unmodifiableSet(summaryFeatures);
- if (getInherited()!=null) return getInherited().getSummaryFeatures();
+ if (summaryFeatures != null) return Collections.unmodifiableSet(summaryFeatures);
+ if (getInherited() != null) return getInherited().getSummaryFeatures();
return Collections.emptySet();
}
public void addSummaryFeature(ReferenceNode feature) {
- if (summaryFeatures==null)
- summaryFeatures=new LinkedHashSet<>();
+ if (summaryFeatures == null)
+ summaryFeatures = new LinkedHashSet<>();
summaryFeatures.add(feature);
}
@@ -585,8 +582,11 @@ public class RankProfile implements Serializable, Cloneable {
}
/**
- * Will take the parser-set textual ranking expressions and turn into objects
+ * Will take the parser-set textual ranking expressions and turn into ranking expression objects,
+ * if not already done
*/
+ // TODO: There doesn't appear to be any good reason to defer parsing of ranking expressions
+ // until this is called. Simplify by parsing them right away.
public void parseExpressions() {
try {
parseRankingExpressions();
@@ -604,20 +604,23 @@ public class RankProfile implements Serializable, Cloneable {
for (Map.Entry<String, Macro> e : getMacros().entrySet()) {
String macroName = e.getKey();
Macro macro = e.getValue();
- RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression());
- macro.setRankingExpression(expr);
- macro.setTextualExpression(expr.getRoot().toString());
+ if (macro.getRankingExpression() == null) {
+ RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression());
+ macro.setRankingExpression(expr);
+ macro.setTextualExpression(expr.getRoot().toString());
+ }
}
}
/**
* Passes ranking expressions on to parser
+ *
* @throws ParseException if either of the ranking expressions could not be parsed
*/
private void parseRankingExpressions() throws ParseException {
- if (getFirstPhaseRankingString() != null)
+ if (getFirstPhaseRankingString() != null && firstPhaseRanking == null)
setFirstPhaseRanking(parseRankingExpression("firstphase", getFirstPhaseRankingString()));
- if (getSecondPhaseRankingString() != null)
+ if (getSecondPhaseRankingString() != null && secondPhaseRanking == null)
setSecondPhaseRanking(parseRankingExpression("secondphase", getSecondPhaseRankingString()));
}
@@ -748,37 +751,50 @@ public class RankProfile implements Serializable, Cloneable {
* referable from this rank profile.
*/
public TypeContext typeContext(QueryProfileRegistry queryProfiles) {
- TypeMapContext context = new TypeMapContext();
+ MapEvaluationTypeContext context = new MapEvaluationTypeContext(getMacros().values().stream()
+ .map(Macro::asExpressionFunction)
+ .collect(Collectors.toList()));
- // Add small constants
+ // Add small and large constants, respectively
getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type()));
- // Add large constants
getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType()));
// Add attributes
- for (SDField field : getSearch().allConcreteFields()) {
- field.getAttributes().forEach((k, a) -> context.setType(FeatureNames.asAttributeFeature(k), a.tensorType().orElse(TensorType.empty)));
- }
+ getSearch().allFields().forEach(field -> addAttributeFeatureTypes(field, context));
+ getSearch().allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context));
// Add query features from rank profile types reached from the "default" profile
for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) {
for (FieldDescription field : queryProfileType.declaredFields().values()) {
TensorType type = field.getType().asTensorType();
- String feature = FeatureNames.asQueryFeature(field.getName());
- TensorType existingType = context.getType(feature);
- if (existingType != null)
+ Optional<Reference> feature = Reference.simple(field.getName());
+ if ( ! feature.isPresent() || ! feature.get().name().equals("query")) continue;
+
+ TensorType existingType = context.getType(feature.get());
+ if ( ! Objects.equals(existingType, context.defaultTypeOf(feature.get())))
type = existingType.dimensionwiseGeneralizationWith(type).orElseThrow( () ->
- new IllegalArgumentException(queryProfileType + " contains query feature " + feature +
+ new IllegalArgumentException(queryProfileType + " contains query feature " + feature.get() +
" with type " + field.getType().asTensorType() +
", but this is already defined " +
- "in another query profile with type " + context.getType(feature)));
- context.setType(feature, type);
+ "in another query profile with type " +
+ context.getType(feature.get())));
+ context.setType(feature.get(), type);
}
}
return context;
}
+ private void addAttributeFeatureTypes(ImmutableSDField field, MapEvaluationTypeContext context) {
+ field.getAttributes().forEach((k, a) -> {
+ String name = k;
+ if (k.equals(field.getBackingField().getName())) // this attribute should take the fields name
+ name = field.getName(); // switch to that - it is separate for imported fields
+ context.setType(FeatureNames.asAttributeFeature(name),
+ a.tensorType().orElse(TensorType.empty));
+ });
+ }
+
/**
* A rank setting. The identity of a rank setting is its field name and type (not value).
* A rank setting is immutable.
@@ -910,7 +926,7 @@ public class RankProfile implements Serializable, Cloneable {
*/
public static class Macro implements Serializable, Cloneable {
- private String name=null;
+ private final String name;
private String textualExpression=null;
private RankingExpression expression=null;
private List<String> formalParams = new ArrayList<>();
@@ -955,7 +971,7 @@ public class RankProfile implements Serializable, Cloneable {
return inline && formalParams.size() == 0; // only inline no-arg macros;
}
- public ExpressionFunction toExpressionMacro() {
+ public ExpressionFunction asExpressionFunction() {
return new ExpressionFunction(getName(), getFormalParams(), getRankingExpression());
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
index a075b9d00fa..7b4d70d85b1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
@@ -16,8 +16,7 @@ import java.util.Set;
* Having both of these mappings consolidated here make it easier to remove dependencies on these mappings at
* run time, since it is essentially only used when building rank profile config at deployment time.
*
- * TODO: Reconsider the difference between local and global maps. Right now, the local maps might better be
- * served from a different class owned by SearchBuilder.
+ * TODO: Rank profiles should be stored under its owning Search instance.
*
* @author Ulf Lilleengen
*/
@@ -31,9 +30,6 @@ public class RankProfileRegistry {
/* These rank profiles can be overridden: 'default' rank profile, as that is documented to work. And 'unranked'. */
static final Set<String> overridableRankProfileNames = new HashSet<>(Arrays.asList("default", "unranked"));
- public RankProfileRegistry() {
- }
-
public static RankProfileRegistry createRankProfileRegistryWithBuiltinRankProfiles(Search search) {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
rankProfileRegistry.addRankProfile(new DefaultRankProfile(search, rankProfileRegistry));
@@ -47,7 +43,7 @@ public class RankProfileRegistry {
* @param rankProfile the rank profile to add
*/
public void addRankProfile(RankProfile rankProfile) {
- if (!rankProfiles.containsKey(rankProfile.getSearch())) {
+ if ( ! rankProfiles.containsKey(rankProfile.getSearch())) {
rankProfiles.put(rankProfile.getSearch(), new LinkedHashMap<>());
}
checkForDuplicateRankProfile(rankProfile);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index f4a0365e36e..1ab76afc9c0 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -199,9 +199,7 @@ public class Search implements Serializable, ImmutableSearch {
@Override
public ImmutableSDField getField(String name) {
ImmutableSDField field = getConcreteField(name);
- if (field != null) {
- return field;
- }
+ if (field != null) return field;
return allImportedFields()
.filter(f -> f.getName().equals(name))
.findFirst()
@@ -248,8 +246,6 @@ public class Search implements Serializable, ImmutableSearch {
* Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents
* they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not
* impact this
- *
- * @return the list of fields in this searchdefinition
*/
public List<SDField> allConcreteFields() {
List<SDField> allFields = new ArrayList<>();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
index 762c0fec838..e7cd21ac834 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
@@ -18,6 +18,7 @@ import com.yahoo.searchdefinition.parser.TokenMgrError;
import com.yahoo.searchdefinition.processing.Processing;
import com.yahoo.vespa.documentmodel.DocumentModel;
import com.yahoo.vespa.model.container.search.QueryProfiles;
+import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.io.IOException;
@@ -34,7 +35,6 @@ import java.util.List;
* expressions, using the setRankXXX() methods, 3) invoke the {@link #build()} method, and 4) retrieve the built
* search objects using the {@link #getSearch(String)} method.
*/
-// TODO: This should be cleaned up and more or maybe completely taken over by MockApplicationPackage
public class SearchBuilder {
private final DocumentTypeManager docTypeMgr = new DocumentTypeManager();
@@ -154,7 +154,7 @@ public class SearchBuilder {
} catch (TokenMgrError e) {
throw new ParseException("Unknown symbol: " + e.getMessage());
} catch (ParseException pe) {
- throw new ParseException(stream.formatException(pe.getMessage()));
+ throw new ParseException(stream.formatException(Exceptions.toMessageString(pe)));
}
return importRawSearch(search);
}
@@ -196,11 +196,7 @@ public class SearchBuilder {
* @throws IllegalStateException Thrown if this method has already been called.
*/
public void build() {
- build(new BaseDeployLogger(), new QueryProfiles());
- }
-
- public void build(DeployLogger logger) {
- build(logger, new QueryProfiles());
+ build(new BaseDeployLogger());
}
/**
@@ -209,12 +205,10 @@ public class SearchBuilder {
*
* @throws IllegalStateException Thrown if this method has already been called.
* @param deployLogger The logger to use during build
- * @param queryProfiles The query profiles contained in the application this search is part of.
*/
- public void build(DeployLogger deployLogger, QueryProfiles queryProfiles) {
- if (isBuilt) {
- throw new IllegalStateException("Searches already built.");
- }
+ public void build(DeployLogger deployLogger) {
+ if (isBuilt) throw new IllegalStateException("Model already built");
+
List<Search> built = new ArrayList<>();
List<SDDocumentType> sdocs = new ArrayList<>();
sdocs.add(SDDocumentType.VESPA_DOCUMENT);
@@ -240,7 +234,7 @@ public class SearchBuilder {
for (Search search : new SearchOrderer().order(searchList)) {
new FieldOperationApplierForSearch().process(search);
// These two needed for a couple of old unit tests, ideally these are just read from app
- process(search, deployLogger, queryProfiles);
+ process(search, deployLogger, new QueryProfiles(queryProfileRegistry));
built.add(search);
}
builder.addToModel(searchList);
@@ -254,8 +248,6 @@ public class SearchBuilder {
/**
* Processes and returns the given {@link Search} object. This method has been factored out of the {@link
* #build()} method so that subclasses can choose not to build anything.
- *
- * @param search The object to build.
*/
protected void process(Search search, DeployLogger deployLogger, QueryProfiles queryProfiles) {
Processing.process(search, deployLogger, rankProfileRegistry, queryProfiles);
@@ -352,7 +344,7 @@ public class SearchBuilder {
rankProfileRegistry,
queryprofileRegistry);
builder.importFile(fileName);
- builder.build(deployLogger, new QueryProfiles());
+ builder.build(deployLogger);
return builder;
}
@@ -368,7 +360,7 @@ public class SearchBuilder {
for (Iterator<Path> i = Files.list(new File(dir).toPath()).filter(p -> p.getFileName().toString().endsWith(".sd")).iterator(); i.hasNext(); ) {
builder.importFile(i.next());
}
- builder.build(new BaseDeployLogger(), new QueryProfiles());
+ builder.build(new BaseDeployLogger());
return builder;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java
deleted file mode 100644
index 40e9db1413f..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchdefinition;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * A context which only contains type information.
- *
- * @author bratseth
- */
-public class TypeMapContext implements TypeContext {
-
- private final Map<String, TensorType> featureTypes = new HashMap<>();
-
- public void setType(String name, TensorType type) {
- featureTypes.put(FeatureNames.canonicalize(name), type);
- }
-
- @Override
- public TensorType getType(String name) {
- return featureTypes.get(FeatureNames.canonicalize(name));
- }
-
- /** Returns an unmodifiable map of the bindings in this */
- public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index ea02f960800..b02362154d9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -188,7 +188,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if (macros.isEmpty()) return;
Map<String, ExpressionFunction> expressionMacros = new LinkedHashMap<>();
for (Map.Entry<String, RankProfile.Macro> macro : macros.entrySet()) {
- expressionMacros.put(macro.getKey(), macro.getValue().toExpressionMacro());
+ expressionMacros.put(macro.getKey(), macro.getValue().asExpressionFunction());
}
Map<String, String> macroProperties = new LinkedHashMap<>();
@@ -223,7 +223,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
// Is the feature a macro?
if (context.getFunction(referenceNode.getName()) != null) {
context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()),
- referenceNode.toString(context, null, null));
+ referenceNode.toString(context, null, null));
ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput());
macroSummaryFeatures.put(referenceNode.getName(), newReferenceNode);
i.remove(); // Will add the expanded one in next block
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
index 8b6df1a87db..4502468379f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java
@@ -63,6 +63,9 @@ public class ImmutableImportedSDField implements ImmutableSDField {
}
@Override
+ public ImmutableSDField getBackingField() { return importedField.targetField(); }
+
+ @Override
public boolean isIndexStructureField() {
return importedField.targetField().isIndexStructureField();
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java
index 152690a6f56..70553d4b57c 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java
@@ -19,6 +19,7 @@ import java.util.Map;
* @author bjorncs
*/
public interface ImmutableSDField {
+
<T extends Expression> boolean containsExpression(Class<T> searchFor);
boolean doesAttributing();
@@ -33,6 +34,12 @@ public interface ImmutableSDField {
boolean isImportedField();
+ /**
+ * Returns the field backing this - the field itself if this is a regular field,
+ * and the target field if this is imported.
+ */
+ ImmutableSDField getBackingField();
+
default boolean isConcreteField() {
return !isImportedField();
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
index 593edc33370..6e7582a98c8 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
@@ -209,6 +209,9 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer,
}
@Override
+ public ImmutableSDField getBackingField() { return this; }
+
+ @Override
public boolean doesAttributing() {
return containsExpression(AttributeExpression.class);
}
@@ -623,8 +626,7 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer,
public RankType getRankType() { return this.rankType; }
/**
- * <p>Returns the search-time attribute settings of this field
- * or null if none is set.</p>
+ * Returns the search-time attribute settings of this field or null if none is set.
*
* <p>TODO: Make unmodifiable.</p>
*/
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 2b997aa25f2..f16697b5ba6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -208,6 +208,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
"' of type " + requiredType + " but this macro is not present in " +
profile);
+ // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
+ // phase and summary features), as it may only resolve correctly given those bindings
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
if ( actualType == null)
throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
index ee65c9bec02..cc634abef01 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
@@ -13,7 +13,7 @@ import com.yahoo.vespa.indexinglanguage.expressions.OutputExpression;
import com.yahoo.vespa.model.container.search.QueryProfiles;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen Hult</a>
+ * @author Simon Thoresen Hult
*/
public class IndexingValues extends Processor {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 90183848094..061a803cb48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -76,8 +76,9 @@ public class Processing {
ImportedFieldsInSummayValidator::new,
FastAccessValidator::new,
ReservedMacroNames::new,
+ RankingExpressionTypeValidator::new,
- // These two should be last.
+ // These should be last.
IndexingValidation::new,
IndexingValues::new);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java
new file mode 100644
index 00000000000..baacceea667
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java
@@ -0,0 +1,82 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+
+/**
+ * Validates the types of all ranking expressions under a search instance:
+ * Some operators constrain the types of inputs, and first-and second-phase expressions
+ * must return scalar values. In addition, the existence of all referred attribute, query and constant
+ * features is ensured.
+ *
+ * @author bratseth
+ */
+public class RankingExpressionTypeValidator extends Processor {
+
+ private final QueryProfileRegistry queryProfiles;
+
+ public RankingExpressionTypeValidator(Search search,
+ DeployLogger deployLogger,
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ this.queryProfiles = queryProfiles.getRegistry();
+ }
+
+ @Override
+ public void process() {
+ for (RankProfile profile : rankProfileRegistry.localRankProfiles(search)) {
+ try {
+ validate(profile);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("In " + search + ", " + profile, e);
+ }
+ }
+ }
+
+ /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */
+ private void validate(RankProfile profile) {
+ profile.parseExpressions();
+ TypeContext context = profile.typeContext(queryProfiles);
+ profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context));
+ ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context);
+ ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context);
+ }
+
+ private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) {
+ if (expression == null) return null;
+ return ensureValid(expression.getRoot(), expressionDescription, context);
+ }
+
+ private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) {
+ TensorType type;
+ try {
+ type = expression.type(context);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("The " + expressionDescription + " is invalid", e);
+ }
+ if (type == null) // Not expected to happen
+ throw new IllegalStateException("Could not determine the type produced by " + expressionDescription);
+ return type;
+ }
+
+ private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) {
+ if (expression == null) return;
+ TensorType type = ensureValid(expression, expressionDescription, context);
+ if ( ! type.equals(TensorType.empty))
+ throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " +
+ "(a tensor with no dimensions), but produces " + type);
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
index 640a85d9b50..15b482ee60c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
@@ -4,21 +4,24 @@ package com.yahoo.vespa.model.container;
import com.yahoo.config.provision.AthenzDomain;
import com.yahoo.config.provision.AthenzService;
import com.yahoo.config.provision.HostName;
+import com.yahoo.container.bundle.BundleInstantiationSpecification;
import com.yahoo.container.core.identity.IdentityConfig;
+import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.vespa.model.container.component.SimpleComponent;
/**
* @author mortent
*/
public class IdentityProvider extends SimpleComponent implements IdentityConfig.Producer {
- public static final String CLASS = "com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl";
+ public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl";
+ public static final String BUNDLE = "vespa-athenz";
private final AthenzDomain domain;
private final AthenzService service;
private final HostName loadBalancerName;
public IdentityProvider(AthenzDomain domain, AthenzService service, HostName loadBalancerName) {
- super(CLASS);
+ super(new ComponentModel(BundleInstantiationSpecification.getFromStrings(CLASS, CLASS, BUNDLE)));
this.domain = domain;
this.service = service;
this.loadBalancerName = loadBalancerName;
diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
index e890b75770b..f5652c31d2a 100644
--- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg
+++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
@@ -24,7 +24,7 @@ rankprofile[0].fef.property[10].value "4"
rankprofile[0].fef.property[11].name "vespa.dump.feature"
rankprofile[0].fef.property[11].value "attribute(foo1).out"
rankprofile[0].fef.property[12].name "vespa.dump.feature"
-rankprofile[0].fef.property[12].value "attribute(bar1.out)"
+rankprofile[0].fef.property[12].value "attribute(bar1)"
rankprofile[0].fef.property[13].name "vespa.dump.feature"
rankprofile[0].fef.property[13].value "attribute(foo2).out"
rankprofile[0].fef.property[14].name "vespa.dump.feature"
@@ -64,7 +64,7 @@ rankprofile[2].fef.property[2].value "10 + feature(arg1).out.out"
rankprofile[2].fef.property[3].name "vespa.summary.feature"
rankprofile[2].fef.property[3].value "attribute(foo1).out"
rankprofile[2].fef.property[4].name "vespa.summary.feature"
-rankprofile[2].fef.property[4].value "attribute(bar1.out)"
+rankprofile[2].fef.property[4].value "attribute(bar1)"
rankprofile[2].fef.property[5].name "vespa.summary.feature"
rankprofile[2].fef.property[5].value "attribute(foo2).out"
rankprofile[2].fef.property[6].name "vespa.summary.feature"
diff --git a/config-model/src/test/derived/rankexpression/rankexpression.sd b/config-model/src/test/derived/rankexpression/rankexpression.sd
index 8ed1f2bab4c..d3e0057cfe1 100644
--- a/config-model/src/test/derived/rankexpression/rankexpression.sd
+++ b/config-model/src/test/derived/rankexpression/rankexpression.sd
@@ -5,12 +5,10 @@ search rankexpression {
field artist type string {
indexing: summary | index
- # index-to: artist, default
}
field title type string {
indexing: summary | index
- # index-to: title, default
}
field surl type string {
@@ -21,6 +19,38 @@ search rankexpression {
indexing: summary | attribute
}
+ field foo1 type int {
+ indexing: attribute
+ }
+
+ field foo2 type int {
+ indexing: attribute
+ }
+
+ field foo3 type int {
+ indexing: attribute
+ }
+
+ field foo4 type int {
+ indexing: attribute
+ }
+
+ field bar1 type int {
+ indexing: attribute
+ }
+
+ field bar2 type int {
+ indexing: attribute
+ }
+
+ field bar3 type int {
+ indexing: attribute
+ }
+
+ field bar4 type int {
+ indexing: attribute
+ }
+
}
rank-profile default {
@@ -33,7 +63,7 @@ search rankexpression {
expression: if(3>2,4,2)
rerank-count: 10
}
- rank-features: attribute(foo1).out attribute(bar1.out)
+ rank-features: attribute(foo1).out attribute(bar1)
rank-features { attribute(foo2).out attribute(bar2).out }
rank-features {
attribute(foo3).out attribute(bar3).out }
@@ -65,7 +95,7 @@ search rankexpression {
file:rankexpression
}
}
- summary-features: attribute(foo1).out attribute(bar1.out)
+ summary-features: attribute(foo1).out attribute(bar1)
summary-features { attribute(foo2).out attribute(bar2).out }
summary-features {
attribute(foo3).out attribute(bar3).out }
diff --git a/config-model/src/test/derived/rankexpression/summary.cfg b/config-model/src/test/derived/rankexpression/summary.cfg
index 00df2e87144..9752a9f55e3 100644
--- a/config-model/src/test/derived/rankexpression/summary.cfg
+++ b/config-model/src/test/derived/rankexpression/summary.cfg
@@ -15,9 +15,25 @@ classes[0].fields[5].name "summaryfeatures"
classes[0].fields[5].type "featuredata"
classes[0].fields[6].name "documentid"
classes[0].fields[6].type "longstring"
-classes[1].id 1787488393
+classes[1].id 1736696699
classes[1].name "attributeprefetch"
classes[1].fields[0].name "year"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo1"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo2"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo3"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo4"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar1"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar2"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar3"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar4"
classes[1].fields[0].type "integer"
classes[1].fields[1].name "rankfeatures"
classes[1].fields[1].type "featuredata"
diff --git a/config-model/src/test/derived/rankexpression/summarymap.cfg b/config-model/src/test/derived/rankexpression/summarymap.cfg
index c810f7282ba..21e6cdf346f 100644
--- a/config-model/src/test/derived/rankexpression/summarymap.cfg
+++ b/config-model/src/test/derived/rankexpression/summarymap.cfg
@@ -7,4 +7,28 @@ override[1].command "rankfeatures"
override[1].arguments ""
override[2].field "summaryfeatures"
override[2].command "summaryfeatures"
-override[2].arguments "" \ No newline at end of file
+override[2].arguments ""
+override[].field "foo1"
+override[].command "attribute"
+override[].arguments "foo1"
+override[].field "foo2"
+override[].command "attribute"
+override[].arguments "foo2"
+override[].field "foo3"
+override[].command "attribute"
+override[].arguments "foo3"
+override[].field "foo4"
+override[].command "attribute"
+override[].arguments "foo4"
+override[].field "bar1"
+override[].command "attribute"
+override[].arguments "bar1"
+override[].field "bar2"
+override[].command "attribute"
+override[].arguments "bar2"
+override[].field "bar3"
+override[].command "attribute"
+override[].arguments "bar3"
+override[].field "bar4"
+override[].command "attribute"
+override[].arguments "bar4" \ No newline at end of file
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 2b231e0cda2..b6ad5372c05 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -35,7 +35,7 @@ rankprofile[3].name "profile2"
rankprofile[3].fef.property[0].name "vespa.rank.firstphase"
rankprofile[3].fef.property[0].value "rankingExpression(firstphase)"
rankprofile[3].fef.property[1].name "rankingExpression(firstphase).rankingScript"
-rankprofile[3].fef.property[1].value "reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x)"
+rankprofile[3].fef.property[1].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)"
rankprofile[3].fef.property[2].name "vespa.type.attribute.f2"
rankprofile[3].fef.property[2].value "tensor(x[2],y[])"
rankprofile[3].fef.property[3].name "vespa.type.attribute.f3"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index a6a9a98db3a..3d64f6b807e 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -28,7 +28,7 @@ search tensor {
rank-profile profile2 {
first-phase {
- expression: matmul(attribute(f4), diag(x[2],y[2],z[3]), x)
+ expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x))
}
}
diff --git a/config-model/src/test/examples/rankpropvars.sd b/config-model/src/test/examples/rankpropvars.sd
index 40f9e73f35a..28959edbc09 100644
--- a/config-model/src/test/examples/rankpropvars.sd
+++ b/config-model/src/test/examples/rankpropvars.sd
@@ -18,8 +18,8 @@ first-phase {
second-phase {
expression {
if (attribute(artist) == query(testvar1),
- 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist),
- 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
+ 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist),
+ 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
}
}
@@ -42,8 +42,8 @@ first-phase {
second-phase {
expression {
if (attribute(artist) == query(testvar1),
- 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist),
- 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
+ 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist),
+ 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
}
}
}
diff --git a/config-model/src/test/examples/simple.sd b/config-model/src/test/examples/simple.sd
index 4fda7f5039e..96b0fa98098 100644
--- a/config-model/src/test/examples/simple.sd
+++ b/config-model/src/test/examples/simple.sd
@@ -116,7 +116,7 @@ search simple {
first-phase {
keep-rank-count:200
rank-score-drop-limit: -13.0
- expression: attribute(year)
+ expression: attribute(popularity)
}
second-phase {
rerank-count: 99
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
index 1f60ad870ec..aa01070d296 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
@@ -18,17 +18,6 @@ import static org.junit.Assert.assertFalse;
public class FeatureNamesTestCase {
@Test
- public void testCanonicalization() {
- assertFalse(FeatureNames.canonicalizeIfValid("foo").isPresent());
- assertEquals("query(bar)", FeatureNames.canonicalize("query(bar)"));
- assertEquals("query(bar)", FeatureNames.canonicalize("query('bar')"));
- assertEquals("constant(bar)", FeatureNames.canonicalize("constant(\"bar\")"));
- assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query(ba.r)"));
- assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query('ba.r')"));
- assertEquals("attribute(\"ba.r\")", FeatureNames.canonicalize("attribute(\"ba.r\")"));
- }
-
- @Test
public void testArgument() {
assertFalse(FeatureNames.argumentOf("foo(bar)").isPresent());
assertFalse(FeatureNames.argumentOf("foo(bar.baz)").isPresent());
@@ -42,17 +31,20 @@ public class FeatureNamesTestCase {
@Test
public void testConstantFeature() {
- assertEquals("constant(\"foo/bar\")", FeatureNames.asConstantFeature("foo/bar"));
+ assertEquals("constant(\"foo/bar\")",
+ FeatureNames.asConstantFeature("foo/bar").toString());
}
@Test
public void testAttributeFeature() {
- assertEquals("attribute(foo)", FeatureNames.asAttributeFeature("foo"));
+ assertEquals("attribute(foo)",
+ FeatureNames.asAttributeFeature("foo").toString());
}
@Test
public void testQueryFeature() {
- assertEquals("query(\"foo.bar\")", FeatureNames.asQueryFeature("foo.bar"));
+ assertEquals("query(\"foo.bar\")",
+ FeatureNames.asQueryFeature("foo.bar").toString());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 442c8bd41bd..11093d9f008 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException {
RankProfileRegistry registry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(registry);
+ SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes());
builder.importString("search test {\n" +
" document test { } \n" +
" rank-profile p1 {}\n" +
" rank-profile p2 {}\n" +
"}");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search search = builder.getSearch();
assertEquals(4, registry.allRankProfiles().size());
@@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search);
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(numeric)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index e94880e61c7..82b9f5ac043 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
builder.importString(
"search test {\n" +
" document test { \n" +
+ " field rating_yelp type int {" +
+ " indexing: attribute" +
+ " }" +
" }\n" +
" \n" +
" rank-profile test {\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 5100ac15c40..ed1b00e2875 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -2,7 +2,10 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
@@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
censorBindingHash(testRankProperties.get(4).toString()));
}
-
@Test
public void testNeuralNetworkSetup() throws ParseException {
+ // Note: the type assigned to query profile and constant tensors here is not the correct type
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])");
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
" expression: sum(final_layer)\n" +
" }\n" +
" }\n" +
- "\n" +
+ " constant W_hidden {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_input {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant W_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(0).toString());
@@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
testRankProperties.get(5).toString());
}
+ private QueryProfileRegistry queryProfileWith(String field, String type) {
+ QueryProfileType queryProfileType = new QueryProfileType("root");
+ queryProfileType.addField(new FieldDescription(field, type));
+ QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry();
+ queryProfileRegistry.getTypeRegistry().register(queryProfileType);
+ QueryProfile profile = new QueryProfile("default");
+ profile.setType(queryProfileType);
+ queryProfileRegistry.register(profile);
+ return queryProfileRegistry;
+ }
+
private String censorBindingHash(String s) {
StringBuilder b = new StringBuilder();
boolean areInHash = false;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 800697b3430..0ce6129ef7f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -38,7 +38,8 @@ class RankProfileSearchFixture {
RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry,
String rankProfiles, String constant, String field)
throws ParseException {
- SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry());
+ this.queryProfileRegistry = queryProfileRegistry;
+ SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry);
String sdContent = "search test {\n" +
" " + (constant != null ? constant : "") + "\n" +
" document test {\n" +
@@ -50,7 +51,6 @@ class RankProfileSearchFixture {
builder.importString(sdContent);
builder.build();
search = builder.getSearch();
- this.queryProfileRegistry = queryProfileRegistry;
}
public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
new file mode 100644
index 00000000000..5f5b40e545f
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
@@ -0,0 +1,239 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static com.yahoo.config.model.test.TestUtil.joinLines;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class RankingExpressionTypeValidatorTestCase {
+
+ @Test
+ public void tensorFirstPhaseMustProduceDouble() throws Exception {
+ try {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorSecondPhaseMustProduceDouble() throws Exception {
+ try {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(a))",
+ " }",
+ " second-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception {
+ try {
+ SearchBuilder searchBuilder = new SearchBuilder();
+ searchBuilder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(if(1>0, attribute(a), attribute(b)))",
+ " }",
+ " }",
+ "}"
+ ));
+ searchBuilder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testMacroInvocationTypes() throws Exception {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " macro macro1(attribute_to_use) {",
+ " expression: attribute(attribute_to_use)",
+ " }",
+ " summary-features {",
+ " macro1(a)",
+ " macro1(b)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ RankProfile profile =
+ builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[],y[])"),
+ summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("macro1(b)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testTensorMacroInvocationTypes_Nested() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " macro return_a() {",
+ " expression: return_first(attribute(a), attribute(b))",
+ " }",
+ " macro return_b() {",
+ " expression: return_second(attribute(a), attribute(b))",
+ " }",
+ " macro return_first(e1, e2) {",
+ " expression: e1",
+ " }",
+ " macro return_second(e1, e2) {",
+ " expression: return_first(e2, e1)",
+ " }",
+ " summary-features {",
+ " return_a",
+ " return_b",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ RankProfile profile =
+ builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[],y[])"),
+ summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void importedFieldsAreAvailable() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search parent {",
+ " document parent {",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.importString(joinLines(
+ "search child {",
+ " document child { ",
+ " field ref type reference<parent> {",
+ "indexing: attribute | summary",
+ " }",
+ " }",
+ " import field ref.a as imported_a {}",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(imported_a))",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ }
+
+ @Test
+ public void undeclaredQueryFeaturesAreAccepted() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: query(foo)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ }
+
+ private Map<String, ReferenceNode> summaryFeatures(RankProfile profile) {
+ return profile.getSummaryFeatures().stream().collect(Collectors.toMap(f -> f.toString(), f -> f));
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 4693ac5cf4d..1e376824b7b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -42,7 +42,7 @@ import static org.junit.Assert.*;
public class RankingExpressionWithTensorFlowTestCase {
private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/");
- private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(\"layer_Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"layer_Variable_1\"), d0, d1), f(a,b)(a + b))";
+ private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))";
@After
public void removeGeneratedConstantTensorFiles() {
@@ -54,8 +54,8 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -65,15 +65,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
public void testTensorFlowReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -85,8 +85,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -99,15 +99,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
public void testTensorFlowReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -119,8 +119,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -128,8 +128,8 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -224,8 +224,8 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("layer_Variable_1", search, Optional.of(10L));
- assertLargeConstant("layer_Variable", search, Optional.of(7840L));
+ assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L));
+ assertLargeConstant("layer_Variable_read", search, Optional.of(7840L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -243,8 +243,8 @@ public class RankingExpressionWithTensorFlowTestCase {
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
- assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty());
- assertLargeConstant("layer_Variable", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.empty());
+ assertLargeConstant("layer_Variable_read", searchFromStored, Optional.empty());
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -253,7 +253,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
+ final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist/saved')",
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index b001db69768..054c9220225 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
import java.util.List;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
public class TensorTransformTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException {
- assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)");
- assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)");
- assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))");
- assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))");
- assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))");
- assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)");
- assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)");
- assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)");
- assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)");
- assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)");
+ assertTransformedExpression("max(1.0,2.0)",
+ "max(1.0,2.0)");
+ assertTransformedExpression("min(attribute(double_field),x)",
+ "min(attribute(double_field),x)");
+ assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))",
+ "max(attribute(double_field),attribute(double_array_field))");
+ assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))",
+ "min(attribute(tensor_field_1),attribute(double_field))");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)",
+ "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)");
+ assertTransformedExpression("min(constant(test_constant_tensor),1.0)",
+ "min(test_constant_tensor,1.0)");
+ assertTransformedExpression("max(constant(base_constant_tensor),1.0)",
+ "max(base_constant_tensor,1.0)");
+ assertTransformedExpression("min(constant(file_constant_tensor),1.0)",
+ "min(constant(file_constant_tensor),1.0)");
+ assertTransformedExpression("max(query(q),1.0)",
+ "max(query(q),1.0)");
+ assertTransformedExpression("max(query(n),1.0)",
+ "max(query(n),1.0)");
}
@Test
public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException {
- assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)");
- assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error.
- assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)",
+ "max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)",
+ "1 + max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)",
+ "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)",
+ "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)",
+ "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)",
+ "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error.
+ assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)",
+ "max(max(attribute(tensor_field_2),x),y)");
}
@Test
public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException {
- assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)");
- assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)");
- assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)");
+ assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)",
+ "max(test_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)",
+ "max(base_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)",
+ "min(constant(file_constant_tensor),x)");
}
@Test
public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException {
- assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)");
- assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
- assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
- assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)");
+ assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)",
+ "min(attribute(double_field) + attribute(tensor_field_1),x)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)",
+ "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)");
+ assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)",
+ "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)");
+ assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)",
+ "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
+ assertTransformedExpression("reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)",
+ "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
}
@Test
public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException {
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)");
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)",
+ "max(tensorFromLabels(attribute(double_array_field)),double_array_field)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)",
+ "max(tensorFromLabels(attribute(double_array_field),x),x)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)");
}
@Test
public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException {
- assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)");
- assertContainsExpression("max(query(n),x)", "max(query(n),x)");
+ assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)");
+ assertTransformedExpression("max(query(n),x)", "max(query(n),x)");
}
@Test
public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException {
- assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)");
- assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)");
- assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)");
- assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)",
+ "max(returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)",
+ "max(wraps_returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)",
+ "max(tensor_inheriting,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)",
+ "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)");
}
- private void assertContainsExpression(String expr, String transformedExpression) throws ParseException {
- assertTrue("Expected expression '" + transformedExpression + "' found",
- containsExpression(expr, transformedExpression));
- }
-
- private boolean containsExpression(String expr, String transformedExpression) throws ParseException {
- for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) {
+ private void assertTransformedExpression(String expected, String original) throws ParseException {
+ for (Pair<String, String> rankPropertyExpression : buildSearch(original)) {
String rankProperty = rankPropertyExpression.getFirst();
if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) {
String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ",""));
- return rankExpression.equals(transformedExpression);
+ assertEquals(expected, rankExpression);
+ return;
}
}
- return false;
+ fail("No 'rankingExpression(firstphase).rankingScript' property produced");
}
private List<Pair<String, String>> buildSearch(String expression) throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = setupQueryProfileTypes();
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
" }\n" +
" }\n" +
"}\n");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
return testRankProperties;
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(n)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private String censorBindingHash(String s) {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java
index a4ab5ebdb5e..aeddd05209f 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java
@@ -25,7 +25,6 @@ import static org.junit.Assert.assertTrue;
* Test validation of changes between a current and next document type used in a document database.
*
* @author toregge
- * @since 2014-11-25
*/
public class DocumentTypeChangeValidatorTest {
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index 9bf2a858476..d3edd1c0ca5 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -737,7 +737,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
// TODO: Propagate all filters
Optional<Hostname> hostname = Optional.ofNullable(request.getProperty("hostname")).map(Hostname::new);
-
+ controller.applications().restart(deploymentId, hostname);
// TODO: Change to return JSON
return new StringResponse("Requested restart of " + path(TenantResource.API_PATH, tenantName,
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
index 0e703cf4cec..5be7fe03319 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java
@@ -8,6 +8,7 @@ import com.yahoo.jdisc.handler.ResponseHandler;
import com.yahoo.jdisc.http.HttpRequest.Method;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
import com.yahoo.jdisc.http.filter.SecurityRequestFilter;
+import com.yahoo.log.LogLevel;
import com.yahoo.vespa.athenz.api.AthenzDomain;
import com.yahoo.vespa.athenz.api.AthenzIdentity;
import com.yahoo.vespa.athenz.api.AthenzPrincipal;
@@ -30,6 +31,7 @@ import javax.ws.rs.WebApplicationException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
+import java.util.logging.Logger;
import static com.yahoo.jdisc.http.HttpRequest.Method.GET;
import static com.yahoo.jdisc.http.HttpRequest.Method.HEAD;
@@ -49,6 +51,8 @@ public class ControllerAuthorizationFilter implements SecurityRequestFilter {
private static final List<Method> WHITELISTED_METHODS = Arrays.asList(GET, OPTIONS, HEAD);
+ private static final Logger log = Logger.getLogger(ControllerAuthorizationFilter.class.getName());
+
private final AthenzClientFactory clientFactory;
private final Controller controller;
private final EntityService entityService;
@@ -261,7 +265,10 @@ public class ControllerAuthorizationFilter implements SecurityRequestFilter {
public void handle(ResponseHandler responseHandler,
DiscFilterRequest request,
WebApplicationException exception) {
- sendErrorResponse(responseHandler, exception.getResponse().getStatus(), exception.getMessage());
+ int statusCode = exception.getResponse().getStatus();
+ String errorMessage = exception.getMessage();
+ log.log(LogLevel.WARNING, String.format("Access denied(%d): %s", statusCode, errorMessage), exception);
+ sendErrorResponse(responseHandler, statusCode, errorMessage);
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java
index f5852b9dfcf..d0154ace4e0 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java
@@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.restapi.impl;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.inject.Inject;
import com.yahoo.container.jaxrs.annotation.Component;
-import com.yahoo.vespa.hosted.controller.api.integration.security.KeyService;
+import com.yahoo.container.jdisc.Ckms;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
@@ -24,20 +24,20 @@ import javax.ws.rs.core.UriBuilder;
public class StatusPageResource implements com.yahoo.vespa.hosted.controller.api.statuspage.StatusPageResource {
private final Client client;
- private final KeyService keyService;
+ private final Ckms ckms;
@Inject
- public StatusPageResource(@Component KeyService keyService) {
- this(keyService, ClientBuilder.newClient());
+ public StatusPageResource(@Component Ckms ckms) {
+ this(ckms, ClientBuilder.newClient());
}
- protected StatusPageResource(KeyService keyService, Client client) {
- this.keyService = keyService;
+ protected StatusPageResource(Ckms ckms, Client client) {
+ this.ckms = ckms;
this.client = client;
}
protected UriBuilder statusPageURL(String page, String since) {
- String[] secrets = keyService.getSecret("vespa_hosted.controller.statuspage_api_key").split(":");
+ String[] secrets = ckms.getSecret("vespa_hosted.controller.statuspage_api_key").split(":");
UriBuilder uriBuilder = UriBuilder.fromUri("https://" + secrets[0] + ".statuspage.io/api/v2/" + page + ".json?api_key=" + secrets[1]);
if (since != null) {
uriBuilder.queryParam("since", since);
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java
index 4e2e4bb15b4..b116ba3b5ee 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java
@@ -3,7 +3,7 @@ package com.yahoo.vespa.hosted.restapi.impl;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.yahoo.vespa.hosted.controller.api.integration.security.KeyService;
+import com.yahoo.container.jdisc.Ckms;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
@@ -30,15 +30,15 @@ public class StatusPageResourceTest {
Client mockClient = Mockito.mock(Client.class);
WebTarget mockTarget = Mockito.mock(WebTarget.class);
Invocation.Builder mockRequest = Mockito.mock(Invocation.Builder.class);
- KeyService keyService = Mockito.mock(KeyService.class);
+ Ckms ckms = Mockito.mock(Ckms.class);
Mockito.when(mockClient.target(Mockito.any(UriBuilder.class))).thenReturn(mockTarget);
Mockito.when(mockTarget.request()).thenReturn(mockRequest);
Mockito.when(mockRequest.get(JsonNode.class)).thenReturn(
new ObjectMapper().readTree("{\"page\":{\"name\":\"Vespa\"}}"));
- Mockito.when(keyService.getSecret(Mockito.any(String.class))).thenReturn("testpage:testkey");
+ Mockito.when(ckms.getSecret(Mockito.any(String.class))).thenReturn("testpage:testkey");
- statusPage = new StatusPageResource(keyService, mockClient);
+ statusPage = new StatusPageResource(ckms, mockClient);
}
diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java
index e81c6325922..2da18e12e40 100644
--- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java
+++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java
@@ -15,15 +15,14 @@ import com.github.dockerjava.api.model.Image;
import com.github.dockerjava.api.model.Network;
import com.github.dockerjava.api.model.Statistics;
import com.github.dockerjava.core.DefaultDockerClientConfig;
+import com.github.dockerjava.core.DockerClientConfig;
import com.github.dockerjava.core.DockerClientImpl;
-import com.github.dockerjava.core.RemoteApiVersion;
import com.github.dockerjava.core.async.ResultCallbackTemplate;
import com.github.dockerjava.core.command.BuildImageResultCallback;
import com.github.dockerjava.core.command.ExecStartResultCallback;
import com.github.dockerjava.core.command.PullImageResultCallback;
import com.github.dockerjava.jaxrs.JerseyDockerCmdExecFactory;
import com.google.inject.Inject;
-import com.yahoo.log.LogLevel;
import com.yahoo.vespa.hosted.dockerapi.metrics.CounterWrapper;
import com.yahoo.vespa.hosted.dockerapi.metrics.Dimensions;
import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper;
@@ -34,7 +33,6 @@ import java.io.File;
import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
-import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
@@ -58,13 +56,11 @@ public class DockerImpl implements Docker {
public static final String DOCKER_CUSTOM_MACVLAN_NETWORK_NAME = "vespa-macvlan";
static final String LABEL_NAME_MANAGEDBY = "com.yahoo.vespa.managedby";
-
- private final int SECONDS_TO_WAIT_BEFORE_KILLING;
- private final boolean fallbackTo123OnErrors;
private static final String FRAMEWORK_CONTAINER_PREFIX = "/";
+
private final DockerConfig config;
- private final boolean inProduction;
- private Optional<DockerImageGarbageCollector> dockerImageGC = Optional.empty();
+ private final Optional<DockerImageGarbageCollector> dockerImageGC;
+ private final int secondsToWaitBeforeKilling;
private CounterWrapper numberOfDockerDaemonFails;
private boolean started = false;
@@ -76,63 +72,40 @@ public class DockerImpl implements Docker {
DockerClient dockerClient;
@Inject
- public DockerImpl(final DockerConfig config, MetricReceiverWrapper metricReceiver) {
- this(config,
- true, /* fallback to 1.23 on errors */
- metricReceiver,
- !config.isRunningLocally());
- }
-
- private DockerImpl(final DockerConfig config,
- boolean fallbackTo123OnErrors,
- MetricReceiverWrapper metricReceiverWrapper,
- boolean inProduction) {
+ public DockerImpl(DockerConfig config, MetricReceiverWrapper metricReceiverWrapper) {
this.config = config;
- this.fallbackTo123OnErrors = fallbackTo123OnErrors;
- this.inProduction = inProduction;
- if (config == null) {
- this.SECONDS_TO_WAIT_BEFORE_KILLING = 10;
- } else {
- SECONDS_TO_WAIT_BEFORE_KILLING = config.secondsToWaitBeforeKillingContainer();
- }
- if (metricReceiverWrapper != null) {
- setMetrics(metricReceiverWrapper);
- }
+
+ secondsToWaitBeforeKilling = Optional.ofNullable(config)
+ .map(DockerConfig::secondsToWaitBeforeKillingContainer)
+ .orElse(10);
+
+ dockerImageGC = Optional.ofNullable(config)
+ .map(DockerConfig::imageGCMinTimeToLiveMinutes)
+ .map(Duration::ofMinutes)
+ .map(DockerImageGarbageCollector::new);
+
+ Optional.ofNullable(metricReceiverWrapper).ifPresent(this::setMetrics);
}
// For testing
DockerImpl(final DockerClient dockerClient) {
- this(null, false, null, false);
+ this(null, null);
this.dockerClient = dockerClient;
}
- // For testing
- DockerImpl(final DockerConfig config,
- boolean fallbackTo123OnErrors,
- MetricReceiverWrapper metricReceiverWrapper) {
- this(config, fallbackTo123OnErrors, metricReceiverWrapper, false);
- }
-
@Override
public void start() {
if (started) return;
started = true;
if (config != null) {
- if (dockerClient == null) {
- dockerClient = initDockerConnection();
- }
- if (inProduction) {
- Duration minAgeToDelete = Duration.ofMinutes(config.imageGCMinTimeToLiveMinutes());
- dockerImageGC = Optional.of(new DockerImageGarbageCollector(minAgeToDelete));
-
+ dockerClient = createDockerClient(config);
- if (!config.networkNATed()) {
- try {
- setupDockerNetworkIfNeeded();
- } catch (Exception e) {
- throw new DockerException("Could not setup docker network", e);
- }
+ if (!config.networkNATed()) {
+ try {
+ setupDockerNetworkIfNeeded();
+ } catch (Exception e) {
+ throw new DockerException("Could not setup docker network", e);
}
}
}
@@ -143,21 +116,6 @@ public class DockerImpl implements Docker {
return config.networkNATed();
}
- static DefaultDockerClientConfig.Builder buildDockerClientConfig(DockerConfig config) {
- DefaultDockerClientConfig.Builder dockerConfigBuilder = new DefaultDockerClientConfig.Builder()
- .withDockerHost(config.uri());
-
- if (URI.create(config.uri()).getScheme().equals("tcp") && !config.caCertPath().isEmpty()) {
- // In current version of docker-java (3.0.2), withDockerTlsVerify() only effect is when using it together
- // with withDockerCertPath(), where setting withDockerTlsVerify() must be set to true, otherwise the
- // cert path parameter will be ignored.
- // withDockerTlsVerify() has no effect when used with withCustomSslConfig()
- dockerConfigBuilder.withCustomSslConfig(new VespaSSLConfig(config));
- }
-
- return dockerConfigBuilder;
- }
-
private void setupDockerNetworkIfNeeded() throws IOException {
if (!dockerClient.listNetworksCmd().withNameFilter(DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).exec().isEmpty()) return;
@@ -366,7 +324,7 @@ public class DockerImpl implements Docker {
@Override
public void stopContainer(final ContainerName containerName) {
try {
- dockerClient.stopContainerCmd(containerName.asString()).withTimeout(SECONDS_TO_WAIT_BEFORE_KILLING).exec();
+ dockerClient.stopContainerCmd(containerName.asString()).withTimeout(secondsToWaitBeforeKilling).exec();
} catch (NotModifiedException ignored) {
// If is already stopped, ignore
} catch (RuntimeException e) {
@@ -545,36 +503,18 @@ public class DockerImpl implements Docker {
}
}
- private DockerClient initDockerConnection() {
+ private static DockerClient createDockerClient(DockerConfig config) {
JerseyDockerCmdExecFactory dockerFactory = new JerseyDockerCmdExecFactory()
.withMaxPerRouteConnections(config.maxPerRouteConnections())
.withMaxTotalConnections(config.maxTotalConnections())
.withConnectTimeout(config.connectTimeoutMillis())
.withReadTimeout(config.readTimeoutMillis());
- RemoteApiVersion remoteApiVersion;
- try {
- remoteApiVersion = RemoteApiVersion.parseConfig(DockerClientImpl.getInstance(
- buildDockerClientConfig(config).build())
- .withDockerCmdExecFactory(dockerFactory).versionCmd().exec().getApiVersion());
- logger.info("Found version of remote docker API: " + remoteApiVersion);
- // From version 1.24 a field was removed which causes trouble with the current docker java code.
- // When this is fixed, we can remove this and do not specify version.
- if (remoteApiVersion.isGreaterOrEqual(RemoteApiVersion.VERSION_1_24)) {
- remoteApiVersion = RemoteApiVersion.VERSION_1_23;
- logger.info("Found version 1.24 or newer of remote API, using 1.23.");
- }
- } catch (Exception e) {
- if (!fallbackTo123OnErrors) {
- throw e;
- }
- logger.log(LogLevel.ERROR, "Failed when trying to figure out remote API version of docker, using 1.23", e);
- remoteApiVersion = RemoteApiVersion.VERSION_1_23;
- }
- return DockerClientImpl.getInstance(
- buildDockerClientConfig(config)
- .withApiVersion(remoteApiVersion)
- .build())
+ DockerClientConfig dockerClientConfig = new DefaultDockerClientConfig.Builder()
+ .withDockerHost(config.uri())
+ .build();
+
+ return DockerClientImpl.getInstance(dockerClientConfig)
.withDockerCmdExecFactory(dockerFactory);
}
diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java
deleted file mode 100644
index 549af0d85cb..00000000000
--- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.dockerapi;
-
-import com.github.dockerjava.api.model.Network;
-import com.yahoo.metrics.simple.MetricReceiver;
-import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper;
-
-import java.io.File;
-
-/**
- * Helper class for testing full integration with docker daemon, requires running daemon. To run these tests:
- *
- * MAC:
- * 1. Install Docker Toolbox, and start it (Docker Quickstart Terminal) (you can close terminal window afterwards)
- * 2. For network test, we need to make docker containers visible for Mac: sudo route add 172.18.0.0/16 192.168.99.100
- *
- * @author freva
- */
-public class DockerTestUtils {
- private static final OS operatingSystem = getSystemOS();
- private static final String prefix = "/Users/" + System.getProperty("user.name") + "/.docker/machine/machines/default/";
- private static final DockerConfig dockerConfig = new DockerConfig(new DockerConfig.Builder()
- .caCertPath( operatingSystem == OS.Mac_OS_X ? prefix + "ca.pem" : "")
- .clientCertPath(operatingSystem == OS.Mac_OS_X ? prefix + "cert.pem" : "")
- .clientKeyPath( operatingSystem == OS.Mac_OS_X ? prefix + "key.pem" : "")
- .uri( operatingSystem == OS.Mac_OS_X ? "tcp://192.168.99.100:2376" : "tcp://localhost:2376")
- .secondsToWaitBeforeKillingContainer(0));
- private static DockerImpl docker;
-
- public static boolean dockerDaemonIsPresent() {
- if (docker != null) return true;
- if (operatingSystem == OS.Unsupported) {
- System.err.println("This test does not support " + System.getProperty("os.name") + " yet, ignoring test.");
- return false;
- }
-
- try {
- getDocker(); // Will throw an exception if docker is not installed/incorrectly configured
- return true;
- } catch (Exception e) {
- System.err.println("Please install Docker Toolbox and start Docker Quick Start Terminal once, ignoring test.");
- System.err.println(e.getMessage());
- return false;
- }
- }
-
- public static DockerImpl getDocker() {
- if (docker == null) {
- DockerImpl tmpDocker = new DockerImpl(
- dockerConfig,
- false, /* fallback to 1.23 on errors */
- new MetricReceiverWrapper(MetricReceiver.nullImplementation));
- tmpDocker.start();
- createDockerTestNetworkIfNeeded(tmpDocker);
- docker = tmpDocker;
- }
-
- return docker;
- }
-
- public static void createDockerTestNetworkIfNeeded(DockerImpl docker) {
- if (! docker.dockerClient.listNetworksCmd().withNameFilter(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).exec().isEmpty()) return;
-
- Network.Ipam ipam = new Network.Ipam().withConfig(new Network.Ipam.Config()
- .withSubnet("172.18.0.0/16")
- .withGateway("172.18.0.1"));
- docker.dockerClient.createNetworkCmd()
- .withName(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withDriver("bridge").withIpam(ipam).exec();
- }
-
- public static void buildSimpleHttpServerDockerImage(DockerImpl docker, DockerImage dockerImage) {
- try {
- docker.deleteImage(dockerImage);
- } catch (Exception e) {
- if (! e.getMessage().equals("Failed to delete docker image " + dockerImage.asString())) {
- throw e;
- }
- }
-
- // Build the image locally
- File dockerFileStream = new File("src/test/resources/simple-ipv6-server");
- docker.buildImage(dockerFileStream, dockerImage);
- }
-
- public enum OS { Linux, Mac_OS_X, Unsupported }
-
- public static OS getSystemOS() {
- switch (System.getProperty("os.name").toLowerCase()) {
- case "linux": return OS.Linux;
- case "mac os x": return OS.Mac_OS_X;
- default: return OS.Unsupported;
- }
- }
-}
diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java
deleted file mode 100644
index e9bc0181dd7..00000000000
--- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.dockerapi;
-
-import com.github.dockerjava.api.exception.DockerClientException;
-import com.github.dockerjava.core.SSLConfig;
-import org.bouncycastle.asn1.ASN1ObjectIdentifier;
-import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
-import org.bouncycastle.cert.X509CertificateHolder;
-import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
-import org.bouncycastle.jce.provider.BouncyCastleProvider;
-import org.bouncycastle.openssl.PEMKeyPair;
-import org.bouncycastle.openssl.PEMParser;
-import org.glassfish.jersey.SslConfigurator;
-
-import javax.net.ssl.SSLContext;
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.Reader;
-import java.io.StringReader;
-import java.nio.file.Files;
-import java.nio.file.Paths;
-import java.security.KeyFactory;
-import java.security.KeyStore;
-import java.security.KeyStoreException;
-import java.security.NoSuchAlgorithmException;
-import java.security.PrivateKey;
-import java.security.Security;
-import java.security.cert.Certificate;
-import java.security.cert.CertificateException;
-import java.security.spec.InvalidKeySpecException;
-import java.security.spec.PKCS8EncodedKeySpec;
-import java.util.ArrayList;
-import java.util.List;
-
-import static java.util.Objects.requireNonNull;
-
-
-/**
- * This class is based off {@link com.github.dockerjava.core.LocalDirectorySSLConfig}, but with the ability to
- * specify path to each of the certificates instead of directory path. Additionally it includes
- * {@link com.github.dockerjava.core.util.CertificateUtils} because of version conflict of with
- * com.google.code.findbugs.annotations
- */
-public class VespaSSLConfig implements SSLConfig {
- private final DockerConfig config;
-
- public VespaSSLConfig(DockerConfig config) {
- this.config = config;
- }
-
- @Override
- public SSLContext getSSLContext() {
- try {
- Security.addProvider(new BouncyCastleProvider());
-
- // properties acrobatics not needed for java > 1.6
- String httpProtocols = System.getProperty("https.protocols");
- System.setProperty("https.protocols", "TLSv1");
- SslConfigurator sslConfig = SslConfigurator.newInstance(true);
- if (httpProtocols != null) {
- System.setProperty("https.protocols", httpProtocols);
- }
-
- String keypem = new String(Files.readAllBytes(Paths.get(config.clientKeyPath())));
- String certpem = new String(Files.readAllBytes(Paths.get(config.clientCertPath())));
- String capem = new String(Files.readAllBytes(Paths.get(config.caCertPath())));
-
- sslConfig.keyStore(createKeyStore(keypem, certpem));
- sslConfig.keyStorePassword("docker");
- sslConfig.trustStore(createTrustStore(capem));
-
- return sslConfig.createSSLContext();
- } catch (Exception e) {
- throw new DockerClientException(e.getMessage(), e);
- }
- }
-
- public static KeyStore createKeyStore(final String keypem, final String certpem) throws NoSuchAlgorithmException,
- IOException, CertificateException, KeyStoreException {
- PrivateKey privateKey = loadPrivateKey(keypem);
- requireNonNull(privateKey);
- List<Certificate> privateCertificates = loadCertificates(certpem);
-
- KeyStore keyStore = KeyStore.getInstance("JKS");
- keyStore.load(null);
-
- keyStore.setKeyEntry("docker",
- privateKey,
- "docker".toCharArray(),
- privateCertificates.toArray(new Certificate[privateCertificates.size()])
- );
-
- return keyStore;
- }
-
- /**
- * from "cert.pem" String
- */
- private static List<Certificate> loadCertificates(final String certpem) throws IOException,
- CertificateException {
- final StringReader certReader = new StringReader(certpem);
- try (BufferedReader reader = new BufferedReader(certReader)) {
- return loadCertificates(reader);
- }
- }
-
- /**
- * "cert.pem" from reader
- */
- private static List<Certificate> loadCertificates(final Reader reader) throws IOException,
- CertificateException {
- try (PEMParser pemParser = new PEMParser(reader)) {
- List<Certificate> certificates = new ArrayList<>();
-
- JcaX509CertificateConverter certificateConverter = new JcaX509CertificateConverter().setProvider("BC");
- Object certObj = pemParser.readObject();
-
- if (certObj instanceof X509CertificateHolder) {
- X509CertificateHolder certificateHolder = (X509CertificateHolder) certObj;
- certificates.add(certificateConverter.getCertificate(certificateHolder));
- }
-
- return certificates;
- }
- }
-
-
- /**
- * Return private key ("key.pem") from Reader
- */
- private static PrivateKey loadPrivateKey(final Reader reader) throws IOException, NoSuchAlgorithmException {
- try (PEMParser pemParser = new PEMParser(reader)) {
- Object readObject = pemParser.readObject();
- while (readObject != null) {
- if (readObject instanceof PEMKeyPair) {
- PEMKeyPair pemKeyPair = (PEMKeyPair) readObject;
- PrivateKey privateKey = guessKey(pemKeyPair.getPrivateKeyInfo().getEncoded());
- if (privateKey != null) {
- return privateKey;
- }
- } else if (readObject instanceof PrivateKeyInfo) {
- PrivateKeyInfo privateKeyInfo = (PrivateKeyInfo) readObject;
- PrivateKey privateKey = guessKey(privateKeyInfo.getEncoded());
- if (privateKey != null) {
- return privateKey;
- }
- } else if (readObject instanceof ASN1ObjectIdentifier) {
- // no idea how it can be used
- final ASN1ObjectIdentifier asn1ObjectIdentifier = (ASN1ObjectIdentifier) readObject;
- }
-
- readObject = pemParser.readObject();
- }
- }
-
- return null;
- }
-
- private static PrivateKey guessKey(byte[] encodedKey) throws NoSuchAlgorithmException {
- //no way to know, so iterate
- for (String guessFactory : new String[]{"RSA", "ECDSA"}) {
- try {
- KeyFactory factory = KeyFactory.getInstance(guessFactory);
-
- PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encodedKey);
- return factory.generatePrivate(privateKeySpec);
- } catch (InvalidKeySpecException ignore) {
- }
- }
-
- return null;
- }
-
- /**
- * Return KeyPair from "key.pem"
- */
- private static PrivateKey loadPrivateKey(final String keypem) throws IOException, NoSuchAlgorithmException {
- try (StringReader certReader = new StringReader(keypem);
- BufferedReader reader = new BufferedReader(certReader)) {
- return loadPrivateKey(reader);
- }
- }
-
- /**
- * "ca.pem" from String
- */
- public static KeyStore createTrustStore(String capem) throws IOException, CertificateException,
- KeyStoreException, NoSuchAlgorithmException {
- try (Reader certReader = new StringReader(capem)) {
- return createTrustStore(certReader);
- }
- }
-
- /**
- * "ca.pem" from Reader
- */
- public static KeyStore createTrustStore(final Reader certReader) throws IOException, CertificateException,
- KeyStoreException, NoSuchAlgorithmException {
- try (PEMParser pemParser = new PEMParser(certReader)) {
- X509CertificateHolder certificateHolder = (X509CertificateHolder) pemParser.readObject();
- Certificate caCertificate = new JcaX509CertificateConverter()
- .setProvider("BC")
- .getCertificate(certificateHolder);
-
- KeyStore trustStore = KeyStore.getInstance("JKS");
- trustStore.load(null);
- trustStore.setCertificateEntry("ca", caCertificate);
-
- return trustStore;
- }
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
-
- VespaSSLConfig that = (VespaSSLConfig) o;
-
- return config.equals(that.config);
-
- }
-
- @Override
- public int hashCode() {
- return config.hashCode();
- }
-}
diff --git a/docker-api/src/main/resources/configdefinitions/docker.def b/docker-api/src/main/resources/configdefinitions/docker.def
index b4585318cd8..83fee05dff6 100644
--- a/docker-api/src/main/resources/configdefinitions/docker.def
+++ b/docker-api/src/main/resources/configdefinitions/docker.def
@@ -1,9 +1,6 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
namespace=vespa.hosted.dockerapi
-caCertPath string default = ""
-clientCertPath string default = ""
-clientKeyPath string default = ""
uri string default = "unix:///host/var/run/docker.sock"
secondsToWaitBeforeKillingContainer int default = 10
diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java
index 12e52dde494..654b5df3f3b 100644
--- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java
+++ b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java
@@ -12,7 +12,6 @@ import com.github.dockerjava.api.command.InspectImageCmd;
import com.github.dockerjava.api.command.InspectImageResponse;
import com.github.dockerjava.api.command.PullImageCmd;
import com.github.dockerjava.api.exception.NotFoundException;
-import com.github.dockerjava.core.DefaultDockerClientConfig;
import com.github.dockerjava.core.command.ExecStartResultCallback;
import com.yahoo.metrics.simple.MetricReceiver;
import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper;
@@ -20,12 +19,6 @@ import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
-import java.io.IOException;
-import java.security.KeyManagementException;
-import java.security.KeyStoreException;
-import java.security.NoSuchAlgorithmException;
-import java.security.UnrecoverableKeyException;
-
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
@@ -39,60 +32,6 @@ import static org.mockito.Mockito.when;
* @author tonytv
*/
public class DockerImplTest {
- @Test
- public void testDockerConfigWithUnixPath() throws UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
- String dockerUri = "unix:///var/run/docker.sock";
- DockerConfig config = createConfig(dockerUri, null, null, null);
- DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build();
-
- assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri));
- assertTrue("SSL config was set when using socket", clientConfig.getSSLConfig() == null);
- }
-
- @Test
- public void testDockerConfigWithTcpPathWithoutSSL() {
- String dockerUri = "tcp://127.0.0.1:2376";
- DockerConfig config = createConfig(dockerUri, null, null, null);
- DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build();
-
- assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri));
- assertTrue("SSL config was set", clientConfig.getSSLConfig() == null);
- }
-
- @Test
- public void testDockerConfigWithTcpPathWithSslConfig() throws IOException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
- String dockerUri = "tcp://127.0.0.1:2376";
- DockerConfig config = createConfig(dockerUri, "/some/path/ca", "/some/path/cert", "/some/path/key");
- DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build();
-
- assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri));
- assertTrue("SSL config was not set", clientConfig.getSSLConfig() != null);
- }
-
- @Test(expected=RuntimeException.class)
- public void testDockerConfigWithTcpPathWithInvalidSslConfig() throws IOException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
- String dockerUri = "tcp://127.0.0.1:2376";
- DockerConfig config = createConfig(dockerUri, "/some/path/ca", "/some/path/cert", "/some/path/key");
- DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build();
-
- assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri));
- assertTrue("SSL config was not set", clientConfig.getSSLConfig() != null);
-
- // SSL certificates are read during the getSSLContext(), the invalid paths should cause a RuntimeException
- clientConfig.getSSLConfig().getSSLContext();
- }
-
- private static DockerConfig createConfig(String uri, String caCertPath, String clientCertPath, String clientKeyPath) {
- DockerConfig.Builder configBuilder = new DockerConfig.Builder();
-
- if (uri != null) configBuilder.uri(uri);
- if (caCertPath != null) configBuilder.caCertPath(caCertPath);
- if (clientCertPath != null) configBuilder.clientCertPath(clientCertPath);
- if (clientKeyPath != null) configBuilder.clientKeyPath(clientKeyPath);
-
- return new DockerConfig(configBuilder);
- }
-
@Test
public void testExecuteCompletes() throws Exception {
diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java
deleted file mode 100644
index 18f87e5ae17..00000000000
--- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java
+++ /dev/null
@@ -1,197 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.dockerapi;
-
-import org.apache.commons.io.IOUtils;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.net.InetAddress;
-import java.net.URL;
-import java.util.Optional;
-import java.util.concurrent.ExecutionException;
-
-import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeTrue;
-
-/**
- * Requires docker daemon, see {@link com.yahoo.vespa.hosted.dockerapi.DockerTestUtils} for more details.
- *
- * @author freva
- * @author dybdahl
- */
-public class DockerTest {
- private DockerImpl docker;
- private static final DockerImage dockerImage = new DockerImage("simple-ipv6-server:Dockerfile");
- private static final String MANAGER_NAME = "docker-test";
-
- // Ignored because the test is very slow (several minutes) when swap is enabled, to disable: (Linux)
- // $ sudo swapoff -a
- @Ignore
- @Test
- public void testOutOfMemoryDoesNotAffectOtherContainers() throws InterruptedException, ExecutionException, IOException {
- String hostName1 = "docker10.test.yahoo.com";
- String hostName2 = "docker11.test.yahoo.com";
- ContainerName containerName1 = new ContainerName("docker-test-1");
- ContainerName containerName2 = new ContainerName("docker-test-2");
- InetAddress inetAddress1 = InetAddress.getByName("172.18.10.10");
- InetAddress inetAddress2 = InetAddress.getByName("172.18.10.11");
-
- docker.createContainerCommand(dockerImage, ContainerResources.from(0, 0.1), containerName1, hostName1)
- .withManagedBy(MANAGER_NAME)
- .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME)
- .withIpAddress(inetAddress1)
- .create();
- docker.startContainer(containerName1);
-
- docker.createContainerCommand(dockerImage, ContainerResources.from(0, 0.1), containerName2, hostName2)
- .withManagedBy(MANAGER_NAME)
- .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME)
- .withIpAddress(inetAddress2)
- .create();
- docker.startContainer(containerName2);
-
- // 137 = 128 + 9 = kill -9 (SIGKILL), doesn't need to be run as "root", but "yahoo" does not exist in this basic image
- assertThat(docker.executeInContainerAsRoot(containerName2, "python", "/pysrc/fillmem.py", "90").getExitStatus(), is(137));
-
- // Verify that both HTTP servers are still up
- testReachabilityFromHost("http://" + inetAddress1.getHostAddress() + "/ping");
- testReachabilityFromHost("http://" + inetAddress2.getHostAddress() + "/ping");
-
- docker.stopContainer(containerName1);
- docker.deleteContainer(containerName1);
-
- docker.stopContainer(containerName2);
- docker.deleteContainer(containerName2);
- }
-
- @Test
- public void testContainerCycle() throws IOException, InterruptedException, ExecutionException {
- final ContainerName containerName = new ContainerName("docker-test-foo");
- final String containerHostname = "hostName1";
-
- docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname)
- .withManagedBy(MANAGER_NAME).create();
- Optional<Container> container = docker.getContainer(containerName);
- assertTrue(container.isPresent());
- assertEquals(container.get().state, Container.State.CREATED);
-
- docker.startContainer(containerName);
- container = docker.getContainer(containerName);
- assertTrue(container.isPresent());
- assertEquals(container.get().state, Container.State.RUNNING);
-
- docker.dockerClient.pauseContainerCmd(containerName.asString()).exec();
- container = docker.getContainer(containerName);
- assertTrue(container.isPresent());
- assertEquals(container.get().state, Container.State.PAUSED);
-
- docker.dockerClient.unpauseContainerCmd(containerName.asString()).exec();
- docker.stopContainer(containerName);
- container = docker.getContainer(containerName);
- assertTrue(container.isPresent());
- assertEquals(container.get().state, Container.State.EXITED);
-
- docker.deleteContainer(containerName);
- assertThat(docker.listAllContainersManagedBy(MANAGER_NAME).isEmpty(), is(true));
- }
-
- /**
- * Test the expected behavior for exec when it times out - it should throw an exception when it times out,
- * and before the process completes.
- *
- * The test timeout value is set quite high to avoid noise if screwdriver is slow but lower than the process time.
- */
- @Test(expected = DockerExecTimeoutException.class, timeout = 2000)
- public void testContainerExecHounorsTimeout() throws IOException, InterruptedException, ExecutionException {
- final ContainerName containerName = new ContainerName("docker-test-foo");
- final String containerHostname = "hostName1";
-
- docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname)
- .withManagedBy(MANAGER_NAME).create();
- docker.startContainer(containerName);
- docker.executeInContainerAsRoot(containerName, 1L, "sh", "-c", "sleep 5");
- }
-
- /**
- * Test the expected behavior for exec that completes before specified timeout - it should return when the process finishes and not
- * wait for the timeout. Some previous tests indicated that this was not behaving correctly.
- *
- * No timeout implies infinite timeout.
- *
- * The test timeout value is set quite high to avoid noise if screwdriver is slow
- */
- @Test(timeout = 4000)
- public void testContainerExecDoesNotBlockUntilTimeoutWhenCommandFinishesBeforeTimeout() throws IOException, InterruptedException, ExecutionException {
- final ContainerName containerName = new ContainerName("docker-test-foo");
- final String containerHostname = "hostName1";
-
- docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname)
- .withManagedBy(MANAGER_NAME).create();
- docker.startContainer(containerName);
- docker.executeInContainerAsRoot(containerName, 2L, "sh", "-c", "echo hei");
-
- // Also test that this is the behavoir when not specifying timeout
- docker.executeInContainerAsRoot(containerName,"sh", "-c", "echo hei");
- }
-
- @Test
- public void testDockerNetworking() throws InterruptedException, ExecutionException, IOException {
- String hostName1 = "docker10.test.yahoo.com";
- String hostName2 = "docker11.test.yahoo.com";
- ContainerName containerName1 = new ContainerName("docker-test-1");
- ContainerName containerName2 = new ContainerName("docker-test-2");
- InetAddress inetAddress1 = InetAddress.getByName("172.18.10.10");
- InetAddress inetAddress2 = InetAddress.getByName("172.18.10.11");
-
- docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName1, hostName1)
- .withManagedBy(MANAGER_NAME)
- .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withIpAddress(inetAddress1).create();
- docker.startContainer(containerName1);
-
- docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName2, hostName2)
- .withManagedBy(MANAGER_NAME)
- .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withIpAddress(inetAddress2).create();
- docker.startContainer(containerName2);
-
- testReachabilityFromHost("http://" + inetAddress1.getHostAddress() + "/ping");
- testReachabilityFromHost("http://" + inetAddress2.getHostAddress() + "/ping");
-
- String[] curlFromNodeToNode = new String[]{"curl", "-g", "http://" + inetAddress2.getHostAddress() + "/ping"};
- ProcessResult result = docker.executeInContainerAsRoot(containerName1, curlFromNodeToNode);
- assertThat("Could not reach " + containerName2.asString() + " from " + containerName1.asString(),
- result.getOutput(), is("pong\n"));
-
- docker.stopContainer(containerName1);
- docker.deleteContainer(containerName1);
-
- docker.stopContainer(containerName2);
- docker.deleteContainer(containerName2);
- }
-
- @Before
- public void setup() throws InterruptedException, ExecutionException, IOException {
- if (docker == null) {
- assumeTrue(DockerTestUtils.dockerDaemonIsPresent());
-
- docker = DockerTestUtils.getDocker();
- DockerTestUtils.buildSimpleHttpServerDockerImage(docker, dockerImage);
- }
-
- // Clean up any non deleted containers from previous tests
- docker.getAllContainersManagedBy(MANAGER_NAME).forEach(container -> {
- if (container.state.isRunning()) docker.stopContainer(container.name);
- docker.deleteContainer(container.name);
- });
- }
-
- private void testReachabilityFromHost(String target) throws IOException, InterruptedException {
- URL url = new URL(target);
- String containerServer = IOUtils.toString(url.openStream());
- assertThat(containerServer, is("pong\n"));
- }
-}
diff --git a/docker-api/src/test/resources/simple-ipv6-server/Dockerfile b/docker-api/src/test/resources/simple-ipv6-server/Dockerfile
deleted file mode 100644
index ee33894dbeb..00000000000
--- a/docker-api/src/test/resources/simple-ipv6-server/Dockerfile
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-FROM gliderlabs/alpine:3.4
-
-# Install python and curl
-RUN apk-install python curl
-
-# Copy source
-ADD src/ pysrc
-
-# Run http server on port 80
-EXPOSE 80
-CMD ["python", "/pysrc/server.py"]
diff --git a/docker-api/src/test/resources/simple-ipv6-server/README b/docker-api/src/test/resources/simple-ipv6-server/README
deleted file mode 100644
index 0cb96035c42..00000000000
--- a/docker-api/src/test/resources/simple-ipv6-server/README
+++ /dev/null
@@ -1,10 +0,0 @@
-Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-This is the source for a basic docker image that runs a python HTTP server listening at IPv6 port 80.
-The server serves two basic paths:
- /ip - returns IP address of the requester
- /ping - returns string "pong"
-
-
-To build the image run:
-$ sudo docker build -t "simple-ipv6-server:Dockerfile" <path to this directory>
diff --git a/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py b/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py
deleted file mode 100644
index b3990bea859..00000000000
--- a/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import sys
-import time
-
-megabyte = [0] * (1024 * 1024 / 8)
-data = megabyte * int(sys.argv[1])
-
-while True:
- time.sleep(1)
- data.extend(megabyte)
diff --git a/docker-api/src/test/resources/simple-ipv6-server/src/server.py b/docker-api/src/test/resources/simple-ipv6-server/src/server.py
deleted file mode 100644
index 9b4d543d4ed..00000000000
--- a/docker-api/src/test/resources/simple-ipv6-server/src/server.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-import socket
-from BaseHTTPServer import HTTPServer
-from SimpleHTTPServer import SimpleHTTPRequestHandler
-
-
-class MyHandler(SimpleHTTPRequestHandler):
- def do_GET(self):
- if self.path == '/ip':
- self.send_response(200)
- self.send_header('Content-type', 'text/html')
- self.end_headers()
- self.wfile.write('Your IP address is %s\n' % self.client_address[0])
- return
-
- elif self.path == '/ping':
- self.send_response(200)
- self.send_header('Content-type', 'text/html')
- self.end_headers()
- self.wfile.write('pong\n')
- return
-
- else:
- self.send_response(404)
- self.send_header('Content-type', 'text/html')
- self.end_headers()
- self.wfile.write('Could not find ' + self.path + '! Try /ping or /ip.\n')
- return
-
-
-class DualHTTPServer(HTTPServer):
- def __init__(self, address, handler):
- self.address_family = socket.AF_INET6 if (':' in address[0]) else socket.AF_INET
- HTTPServer.__init__(self, address, handler)
-
-
-def main(ipv6):
- server = DualHTTPServer(('::' if ipv6 else '', 80), MyHandler)
- server.serve_forever()
-
-
-if __name__ == '__main__':
- main(False)
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/ConfigServerClientsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/RealConfigServerClients.java
index 43a2c66a9e5..b8e16ee5910 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/ConfigServerClientsImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/RealConfigServerClients.java
@@ -3,7 +3,7 @@ package com.yahoo.vespa.hosted.node.admin.configserver;
import com.yahoo.vespa.hosted.node.admin.component.Environment;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepository;
-import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepositoryImpl;
+import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.RealNodeRepository;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.Orchestrator;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.OrchestratorImpl;
@@ -12,25 +12,25 @@ import java.util.Optional;
/**
* @author freva
*/
-public class ConfigServerClientsImpl implements ConfigServerClients {
+public class RealConfigServerClients implements ConfigServerClients {
private final Optional<ConfigServerApi> configServerApi;
private final NodeRepository nodeRepository;
private final Orchestrator orchestrator;
- public ConfigServerClientsImpl(Environment environment) {
+ public RealConfigServerClients(Environment environment) {
this(new SslConfigServerApiImpl(environment));
}
- public ConfigServerClientsImpl(NodeRepository nodeRepository, Orchestrator orchestrator) {
+ public RealConfigServerClients(NodeRepository nodeRepository, Orchestrator orchestrator) {
this(nodeRepository, orchestrator, Optional.empty());
}
- private ConfigServerClientsImpl(ConfigServerApi configServerApi) {
- this(new NodeRepositoryImpl(configServerApi), new OrchestratorImpl(configServerApi), Optional.of(configServerApi));
+ private RealConfigServerClients(ConfigServerApi configServerApi) {
+ this(new RealNodeRepository(configServerApi), new OrchestratorImpl(configServerApi), Optional.of(configServerApi));
}
- private ConfigServerClientsImpl(NodeRepository nodeRepository, Orchestrator orchestrator,
+ private RealConfigServerClients(NodeRepository nodeRepository, Orchestrator orchestrator,
Optional<ConfigServerApi> configServerApi) {
this.nodeRepository = nodeRepository;
this.orchestrator = orchestrator;
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java
index f2152dffc0c..5b22866fa15 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java
@@ -1,18 +1,20 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.node.admin.configserver.noderepository;
-import com.yahoo.vespa.hosted.node.admin.ContainerAclSpec;
-import com.yahoo.vespa.hosted.node.admin.ContainerNodeSpec;
import com.yahoo.vespa.hosted.dockerapi.ContainerName;
import com.yahoo.vespa.hosted.dockerapi.DockerImage;
+import com.yahoo.vespa.hosted.node.admin.ContainerAclSpec;
+import com.yahoo.vespa.hosted.node.admin.ContainerNodeSpec;
+import com.yahoo.vespa.hosted.node.admin.component.Environment;
import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerApi;
-import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAttributes;
+import com.yahoo.vespa.hosted.node.admin.configserver.HttpException;
+import com.yahoo.vespa.hosted.node.admin.configserver.SslConfigServerApiImpl;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.GetAclResponse;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.GetNodesResponse;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.NodeMessageResponse;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.UpdateNodeAttributesRequestBody;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.UpdateNodeAttributesResponse;
-import com.yahoo.vespa.hosted.node.admin.configserver.HttpException;
+import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAttributes;
import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger;
import com.yahoo.vespa.hosted.provision.Node;
@@ -26,15 +28,19 @@ import java.util.stream.Collectors;
/**
* @author stiankri, dybis
*/
-public class NodeRepositoryImpl implements NodeRepository {
- private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(NodeRepositoryImpl.class);
+public class RealNodeRepository implements NodeRepository {
+ private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(RealNodeRepository.class);
private final ConfigServerApi configServerApi;
- public NodeRepositoryImpl(ConfigServerApi configServerApi) {
+ public RealNodeRepository(ConfigServerApi configServerApi) {
this.configServerApi = configServerApi;
}
+ public RealNodeRepository(Environment environment) {
+ this(new SslConfigServerApiImpl(environment));
+ }
+
@Override
public List<ContainerNodeSpec> getContainersToRun(String baseHostName) {
final GetNodesResponse nodesForHost = configServerApi.get(
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java
index 2d261195213..bc8a45f2dfb 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java
@@ -33,8 +33,6 @@ import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
-import static com.yahoo.vespa.defaults.Defaults.getDefaults;
-
/**
* Class that wraps the Docker class and have some tools related to running programs in docker.
*
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java
index d19f64a2bc3..4b806c905d9 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java
@@ -11,7 +11,7 @@ import com.yahoo.vespa.hosted.node.admin.component.AdminComponent;
import com.yahoo.vespa.hosted.node.admin.component.Environment;
import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig;
import com.yahoo.vespa.hosted.node.admin.component.DockerAdminComponent;
-import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerClientsImpl;
+import com.yahoo.vespa.hosted.node.admin.configserver.RealConfigServerClients;
import com.yahoo.vespa.hosted.node.admin.provider.NodeAdminStateUpdater;
import java.io.File;
@@ -66,7 +66,7 @@ public class NodeAdminMain implements AutoCloseable {
docker,
metricReceiver,
classLocking,
- new ConfigServerClientsImpl(new Environment(configServerConfig)));
+ new RealConfigServerClients(new Environment(configServerConfig)));
}
logger.log(LogLevel.INFO, () -> {
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java
index 2d93dff80a4..4851ad71ebb 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java
@@ -6,7 +6,7 @@ import com.yahoo.vespa.hosted.dockerapi.DockerImage;
import java.util.Objects;
-// It somewhat sucks that this class almost duplicates a binding class used by NodeRepositoryImpl,
+// It somewhat sucks that this class almost duplicates a binding class used by RealNodeRepository,
// but using the binding class here would be a layer violation, and would also tie this logic to
// serialization-related dependencies it needs not have.
public class NodeAttributes {
diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java
index 85e101714e8..fb3416615da 100644
--- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImplTest.java
+++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java
@@ -34,7 +34,7 @@ import static org.junit.Assert.fail;
*
* @author dybdahl
*/
-public class NodeRepositoryImplTest {
+public class RealNodeRepositoryTest {
private JDisc container;
private ConfigServerApiImpl configServerApi;
@@ -74,7 +74,7 @@ public class NodeRepositoryImplTest {
private void waitForJdiscContainerToServe() throws InterruptedException {
Instant start = Instant.now();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
while (Instant.now().minusSeconds(120).isBefore(start)) {
try {
nodeRepositoryApi.getContainersToRun("foobar");
@@ -96,7 +96,7 @@ public class NodeRepositoryImplTest {
@Test
public void testGetContainersToRunApi() throws InterruptedException {
waitForJdiscContainerToServe();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
String dockerHostHostname = "dockerhost1.yahoo.com";
final List<ContainerNodeSpec> containersToRun = nodeRepositoryApi.getContainersToRun(dockerHostHostname);
@@ -115,7 +115,7 @@ public class NodeRepositoryImplTest {
@Test
public void testGetContainer() throws InterruptedException, IOException {
waitForJdiscContainerToServe();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
String hostname = "host4.yahoo.com";
Optional<ContainerNodeSpec> nodeSpec = nodeRepositoryApi.getContainerNodeSpec(hostname);
assertThat(nodeSpec.isPresent(), is(true));
@@ -125,7 +125,7 @@ public class NodeRepositoryImplTest {
@Test
public void testGetContainerForNonExistingNode() throws InterruptedException, IOException {
waitForJdiscContainerToServe();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
String hostname = "host-that-does-not-exist";
Optional<ContainerNodeSpec> nodeSpec = nodeRepositoryApi.getContainerNodeSpec(hostname);
assertFalse(nodeSpec.isPresent());
@@ -134,7 +134,7 @@ public class NodeRepositoryImplTest {
@Test
public void testUpdateNodeAttributes() throws InterruptedException, IOException {
waitForJdiscContainerToServe();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
String hostname = "host4.yahoo.com";
nodeRepositoryApi.updateNodeAttributes(
hostname,
@@ -147,7 +147,7 @@ public class NodeRepositoryImplTest {
@Test(expected = RuntimeException.class)
public void testUpdateNodeAttributesWithBadValue() throws InterruptedException, IOException {
waitForJdiscContainerToServe();
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
String hostname = "host4.yahoo.com";
nodeRepositoryApi.updateNodeAttributes(
hostname,
@@ -159,7 +159,7 @@ public class NodeRepositoryImplTest {
@Test
public void testMarkAsReady() throws InterruptedException, IOException {
- NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi);
+ NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi);
waitForJdiscContainerToServe();
nodeRepositoryApi.markAsDirty("host5.yahoo.com");
diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp
index 466df61f8d0..a0cc89d15c6 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp
+++ b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp
@@ -5,6 +5,43 @@
#include <vespa/log/log.h>
LOG_SETUP(".queryperf");
+namespace {
+
+struct MyLogTask : vespalib::Executor::Task {
+ uint32_t queueLen;
+ uint32_t activeCnt;
+ uint32_t queryCnt;
+ uint32_t dropCnt;
+ uint32_t timeoutCnt;
+ double avgQueryTime;
+ MyLogTask(uint32_t queueLen_in,
+ uint32_t activeCnt_in,
+ uint32_t queryCnt_in,
+ uint32_t dropCnt_in,
+ uint32_t timeoutCnt_in,
+ double avgQueryTime_in)
+ : queueLen(queueLen_in),
+ activeCnt(activeCnt_in),
+ queryCnt(queryCnt_in),
+ dropCnt(dropCnt_in),
+ timeoutCnt(timeoutCnt_in),
+ avgQueryTime(avgQueryTime_in)
+ {
+ }
+ void run() override {
+ EV_VALUE("queued_queries", queueLen);
+ EV_VALUE("active_queries", activeCnt);
+ EV_COUNT("queries", queryCnt);
+ EV_COUNT("dropped_queries", dropCnt);
+ EV_COUNT("timedout_queries", timeoutCnt);
+ if (avgQueryTime > 0.0) {
+ EV_VALUE("query_eval_time_avg_s", avgQueryTime);
+ }
+ }
+};
+
+} // namespace <unnamed>
+
FastS_QueryPerf::FastS_QueryPerf()
: queueLen(0),
activeCnt(0),
@@ -28,19 +65,20 @@ FastS_QueryPerf::reset()
timeoutCnt = 0;
}
-void
-FastS_QueryPerf::log()
+vespalib::Executor::Task::UP
+FastS_QueryPerf::make_log_task()
{
- EV_VALUE("queued_queries", queueLen);
- EV_VALUE("active_queries", activeCnt);
- EV_COUNT("queries", queryCnt);
- EV_COUNT("dropped_queries", dropCnt);
- EV_COUNT("timedout_queries", timeoutCnt);
+ double avgQueryTime = 0.0;
if (queryCnt > _lastQueryCnt) {
- double avgQueryTime = (queryTime - _lastQueryTime)
- / ((double)(queryCnt - _lastQueryCnt));
- EV_VALUE("query_eval_time_avg_s", avgQueryTime);
+ avgQueryTime = (queryTime - _lastQueryTime)
+ / ((double)(queryCnt - _lastQueryCnt));
}
_lastQueryCnt = queryCnt;
_lastQueryTime = queryTime;
+ return std::make_unique<MyLogTask>(queueLen,
+ activeCnt,
+ queryCnt,
+ dropCnt,
+ timeoutCnt,
+ avgQueryTime);
}
diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h
index c4f20bc3cef..ee31a8e58b2 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h
+++ b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h
@@ -3,6 +3,7 @@
#pragma once
#include <cstdint>
+#include <vespa/vespalib/util/executor.h>
struct FastS_QueryPerf
{
@@ -20,7 +21,7 @@ struct FastS_QueryPerf
* prepare the object for reuse logging wise.
**/
void reset();
- void log();
+ vespalib::Executor::Task::UP make_log_task();
private:
uint32_t _lastQueryCnt;
diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp
index d9f2b4ecd4f..b68566c3c9b 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp
+++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp
@@ -92,6 +92,7 @@ Fdispatch::~Fdispatch()
LOG(debug, "Will close threadpool");
_mypool->Close();
+ _executor.shutdown().sync();
LOG(debug, "Has closed threadpool");
_transportServer.reset();
_engineAdapter.reset();
@@ -194,7 +195,8 @@ Fdispatch::CheckTempFail()
* Set up stuff as specified in the fdispatch-rc-file.
*/
Fdispatch::Fdispatch(const config::ConfigUri &configUri)
- : _mypool(),
+ : _executor(1, 128 * 1024),
+ _mypool(),
_engineAdapter(),
_transportServer(),
_componentConfig(),
@@ -391,7 +393,7 @@ Fdispatch::Init()
void
Fdispatch::logPerformance()
{
- _nodeManager->logPerformance();
+ _nodeManager->logPerformance(_executor);
}
uint32_t
diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h
index a0294e22655..6cfb4bfb5a1 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h
+++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h
@@ -11,6 +11,7 @@
#include <vespa/config/helper/configfetcher.h>
#include <vespa/vespalib/net/simple_component_config_producer.h>
#include <vespa/vespalib/util/random.h>
+#include <vespa/vespalib/util/threadstackexecutor.h>
class FastS_NodeManager;
class FastS_fdispatch_RPC;
@@ -62,6 +63,7 @@ private:
Fdispatch(const Fdispatch &);
Fdispatch& operator=(const Fdispatch &);
+ vespalib::ThreadStackExecutor _executor;
std::unique_ptr<FastOS_ThreadPool> _mypool;
std::unique_ptr<EngineAdapter> _engineAdapter;
std::unique_ptr<TransportServer> _transportServer;
diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp
index 4b272a615a6..302f92cef39 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp
+++ b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp
@@ -391,7 +391,7 @@ FastS_NodeManager::getChildInfo()
void
-FastS_NodeManager::logPerformance()
+FastS_NodeManager::logPerformance(vespalib::Executor &executor)
{
_queryPerf.reset();
FastS_DataSetCollection *dsc = GetDataSetCollection();
@@ -403,7 +403,7 @@ FastS_NodeManager::logPerformance()
}
dsc->subRef();
- _queryPerf.log();
+ executor.execute(_queryPerf.make_log_task());
}
diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h
index e0396b46748..77d4482fba7 100644
--- a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h
+++ b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h
@@ -8,6 +8,7 @@
#include <vespa/searchcore/fdispatch/common/queryperf.h>
#include <vespa/vespalib/net/simple_component_config_producer.h>
#include <vespa/config/subscription/configuri.h>
+#include <vespa/vespalib/util/executor.h>
#include <mutex>
using vespa::config::search::core::PartitionsConfig;
@@ -92,7 +93,7 @@ public:
* log query performance. This method should only be invoked from
* the FNET thread.
**/
- void logPerformance();
+ void logPerformance(vespalib::Executor &executor);
void CheckEvents(FastS_TimeKeeper *timeKeeper); // invoked by FNET thread
};
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 2e2858da238..262aba89f27 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.text.Utf8;
@@ -11,9 +12,9 @@ import java.security.NoSuchAlgorithmException;
import java.util.*;
/**
- * <p>A function defined by a ranking expression</p>
+ * A function defined by a ranking expression
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bratseth
*/
public class ExpressionFunction {
@@ -23,7 +24,7 @@ public class ExpressionFunction {
private final RankingExpression body;
/**
- * <p>Constructs a new function</p>
+ * Constructs a new function
*
* @param name the name of this function
* @param arguments its argument names
@@ -43,28 +44,27 @@ public class ExpressionFunction {
public RankingExpression getBody() { return body; }
/**
- * <p>Create and return an instance of this function based on the given
- * arguments. If function calls are nested, this call might produce
- * additional scripts.</p>
+ * Creates and returns an instance of this function based on the given
+ * arguments. If function calls are nested, this call may produce
+ * additional functions.
*
* @param context the context used to expand this
- * @param arguments the arguments to instantiate on.
+ * @param argumentValues the arguments to instantiate on.
* @param path the expansion path leading to this.
* @return the script function instance created.
*/
- public Instance expand(SerializationContext context, List<ExpressionNode> arguments, Deque<String> path) {
+ public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) {
Map<String, String> argumentBindings = new HashMap<>();
- for (int i = 0; i < this.arguments.size() && i < arguments.size(); ++i) {
- argumentBindings.put(this.arguments.get(i), arguments.get(i).toString(context, path, null));
+ for (int i = 0; i < arguments.size() && i < arguments.size(); ++i) {
+ argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(context, path, null));
}
- return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.createBinding(argumentBindings), path, null));
+ return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.withBindings(argumentBindings), path, null));
}
/**
* Returns a symbolic string that represents this function with a given
* list of arguments. The arguments are mangled by hashing the string
- * representation of the argument expressions, so we might need to revisit
- * this if we start seeing collisions.
+ * representation of the argument expressions.
*
* @param argumentBindings the bound arguments to include in the symbolic name.
* @return the symbolic name for an instance of this function
@@ -85,8 +85,8 @@ public class ExpressionFunction {
/**
- * <p>Returns a more unique hash code than what Java's own {@link
- * String#hashCode()} method would produce.</p>
+ * Returns a more unique hash code than what Java's own {@link
+ * String#hashCode()} method would produce.
*
* @param str The string to hash.
* @return A 64 bit long hash code.
@@ -136,4 +136,5 @@ public class ExpressionFunction {
}
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
index 49466f1974d..f0532d9d433 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
@@ -91,8 +91,8 @@ public class FeatureList implements Iterable<ReferenceNode> {
/**
* Returns the feature at the given index.
*
- * @param i The index of the feature to return.
- * @return The featuer at the given index.
+ * @param i the index of the feature to return.
+ * @return the feature at the given index.
*/
public ReferenceNode get(int i) {
return features.get(i);
@@ -137,4 +137,5 @@ public class FeatureList implements Iterable<ReferenceNode> {
public Iterator<ReferenceNode> iterator() {
return features.iterator();
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index c8d90e8c4e8..6b2422d7cb2 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -244,10 +244,6 @@ public class RankingExpression implements Serializable {
* @return a list of named rank properties required to implement this expression.
*/
public Map<String, String> getRankProperties(List<ExpressionFunction> macros) {
- Map<String, ExpressionFunction> arg = new HashMap<>();
- for (ExpressionFunction function : macros) {
- arg.put(function.getName(), function);
- }
Deque<String> path = new LinkedList<>();
SerializationContext context = new SerializationContext(macros);
String serializedRoot = root.toString(context, path, null);
@@ -272,7 +268,7 @@ public class RankingExpression implements Serializable {
*
* @throws IllegalArgumentException if this expression is not type correct in this context
*/
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return root.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
new file mode 100644
index 00000000000..6277721e8f5
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
@@ -0,0 +1,121 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression;
+
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Deque;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A reference to a feature, function, or value in ranking expressions
+ *
+ * @author bratseth
+ */
+public class Reference extends TypeContext.Name {
+
+ private final String name;
+ private final Arguments arguments;
+
+ /**
+ * The output, or null if none
+ */
+ private final String output;
+
+ public Reference(String name, Arguments arguments, String output) {
+ super(name);
+ Objects.requireNonNull(name, "name cannot be null");
+ Objects.requireNonNull(arguments, "arguments cannot be null");
+ this.name = name;
+ this.arguments = arguments;
+ this.output = output;
+ }
+
+ public String name() { return name; }
+
+ public Arguments arguments() { return arguments; }
+
+ public String output() { return output; }
+
+ /**
+ * Creates a reference to a simple feature consisting of a name and a single argument
+ */
+ public static Reference simple(String name, String argumentValue) {
+ return new Reference(name,
+ new Arguments(new ReferenceNode(argumentValue)),
+ null);
+ }
+
+ /**
+ * Returns the given simple feature as a reference, or empty if it is not a valid simple
+ * feature string on the form name(argument).
+ */
+ public static Optional<Reference> simple(String feature) {
+ int startParenthesis = feature.indexOf('(');
+ if (startParenthesis < 0)
+ return Optional.empty();
+ int endParenthesis = feature.lastIndexOf(')');
+ String featureName = feature.substring(0, startParenthesis);
+ if (startParenthesis < 1 || endParenthesis < startParenthesis) return Optional.empty();
+ String argument = feature.substring(startParenthesis + 1, endParenthesis);
+ if (argument.startsWith("'") || argument.startsWith("\""))
+ argument = argument.substring(1);
+ if (argument.endsWith("'") || argument.endsWith("\""))
+ argument = argument.substring(0, argument.length() - 1);
+ return Optional.of(simple(featureName, argument));
+ }
+
+ /**
+ * Returns whether this is a simple identifier - no arguments or output
+ */
+ public boolean isIdentifier() {
+ return this.arguments.expressions().size() == 0 && output == null;
+ }
+
+ public Reference withArguments(Arguments arguments) {
+ return new Reference(name, arguments, output);
+ }
+
+ public Reference withOutput(String output) {
+ return new Reference(name, arguments, output);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if (!(o instanceof Reference)) return false;
+ Reference other = (Reference) o;
+ if (!Objects.equals(other.name, this.name)) return false;
+ if (!Objects.equals(other.arguments, this.arguments)) return false;
+ if (!Objects.equals(other.output, this.output)) return false;
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, arguments, output);
+ }
+
+ @Override
+ public String toString() {
+ return toString(new SerializationContext(), null, null);
+ }
+
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ StringBuilder b = new StringBuilder(name);
+ if (arguments != null && arguments.expressions().size() > 0)
+ b.append("(").append(arguments.expressions().stream()
+ .map(node -> node.toString(context, path, parent))
+ .collect(Collectors.joining(","))).append(")");
+ if (output != null)
+ b.append(".").append(output);
+ return b.toString();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index 5f8daa69ecf..ee5952d9aea 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -82,8 +83,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
}
@Override
- public TensorType getType(String name) {
- Integer index = nameToIndex().get(name);
+ public TensorType getType(Reference reference) {
+ Integer index = nameToIndex().get(reference.toString());
if (index == null) return null;
return values[index].type();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 861f9565d66..4e046df11ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -1,9 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -14,7 +16,7 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public abstract class Context implements EvaluationContext {
+public abstract class Context implements EvaluationContext<Reference> {
/**
* Returns the value of a simple variable name.
@@ -24,6 +26,11 @@ public abstract class Context implements EvaluationContext {
*/
public abstract Value get(String name);
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
/** Returns a variable as a tensor */
@Override
public Tensor getTensor(String name) { return get(name).asTensor(); }
@@ -46,6 +53,7 @@ public abstract class Context implements EvaluationContext {
* calculation to output several), or null to output the
* "main" (or only) value.
*/
+ // TODO: Remove/change to use reference?
public Value get(String name, Arguments arguments, String output) {
if (arguments != null && arguments.expressions().size() > 0)
name = name + "(" + arguments.expressions().stream().map(ExpressionNode::toString).collect(Collectors.joining(",")) + ")";
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index 0625e8506cc..0004036da4b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
/**
@@ -68,7 +69,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
}
@Override
- public TensorType getType(String name) { return TensorType.empty; }
+ public TensorType getType(Reference reference) {
+ return TensorType.empty; // Double only
+ }
/** Perform a slow lookup by name */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index a81d0c89f8f..4ef24d60bba 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
@@ -15,7 +16,7 @@ import java.util.Set;
*/
public class MapContext extends Context {
- private Map<String, Value> bindings = new HashMap<>();
+ private Map<String, Value> bindings = new HashMap<>(); // TODO: Change String to Reference
private boolean frozen = false;
@@ -42,8 +43,8 @@ public class MapContext extends Context {
/** Returns the type of the given value key, or null if it is not bound. */
@Override
- public TensorType getType(String key) {
- Value value = bindings.get(key);
+ public TensorType getType(Reference key) {
+ Value value = bindings.get(key.toString());
if (value == null) return null;
return value.type();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java
new file mode 100644
index 00000000000..2a42e2d92f7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java
@@ -0,0 +1,38 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A context which only contains type information.
+ *
+ * @author bratseth
+ */
+public class MapTypeContext implements TypeContext<Reference> {
+
+ private final Map<Reference, TensorType> featureTypes = new HashMap<>();
+
+ public void setType(Reference reference, TensorType type) {
+ featureTypes.put(reference, type);
+ }
+
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ return featureTypes.get(reference);
+ }
+
+ /** Returns an unmodifiable map of the bindings in this */
+ public Map<Reference, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
deleted file mode 100644
index ff2088263d8..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * A context which only contains type information.
- *
- * @author bratseth
- */
-public class TypeMapContext implements TypeContext {
-
- private final Map<String, TensorType> featureTypes = new HashMap<>();
-
- public void setType(String name, TensorType type) {
- featureTypes.put(name, type);
- }
-
- @Override
- public TensorType getType(String name) {
- return featureTypes.get(name);
- }
-
- /** Returns an unmodifiable map of the bindings in this */
- public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index 8ee4cdbf297..649c70122f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -26,7 +27,7 @@ public class GBDTForestNode extends ExpressionNode {
}
@Override
- public final TensorType type(TypeContext context) { return TensorType.empty; }
+ public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index aac635b2545..53a286f09f6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -51,7 +52,7 @@ public final class GBDTNode extends ExpressionNode {
public final double[] values() { return values; }
@Override
- public final TensorType type(TypeContext context) { return TensorType.empty; }
+ public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
deleted file mode 100644
index 5f0c016881a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorProto;
-import org.tensorflow.framework.TensorShapeProto;
-
-/**
- * @author lesters
- */
-public class AttrValueConverter {
-
- public static Tensor toVespaTensor(NodeDef tfNode, String attr) {
- if (!tfNode.getAttrMap().containsKey(attr)) {
- throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr);
- }
- AttrValue attrValue = tfNode.getAttrMap().get(attr);
- switch (attrValue.getValueCase()) {
- case TENSOR:
- return buildFromTensor(attrValue);
- case B:
- return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0);
- case F:
- return buildFromSingleValue(attrValue.getF());
- case I:
- return buildFromSingleValue(attrValue.getI());
- }
-
- throw new IllegalArgumentException(tfNode.getName() +
- ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'");
- }
-
- private static Tensor buildFromSingleValue(double value) {
- TensorType type = new TensorType.Builder().build();
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- builder.cellByDirectIndex(0, value);
- return builder.build();
- }
-
- private static Tensor buildFromTensor(AttrValue attrValue) {
- TensorProto tensorProto = attrValue.getTensor();
- TensorType type = toVespaTensorType(tensorProto.getTensorShape());
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- Values values = valuesOf(tensorProto);
- for (int i = 0; i < values.size(); ++i) {
- builder.cellByDirectIndex(i, values.get(i));
- }
- Tensor tensor = builder.build();
- return tensor;
- }
-
- private static Values valuesOf(TensorProto tensorProto) {
- switch (tensorProto.getDtype()) {
- case DT_BOOL:
- return new BoolValues(tensorProto);
- case DT_HALF:
- return new HalfValues(tensorProto);
- case DT_INT16:
- case DT_INT32:
- return new IntValues(tensorProto);
- case DT_INT64:
- return new Int64Values(tensorProto);
- case DT_FLOAT:
- return new FloatValues(tensorProto);
- case DT_DOUBLE:
- return new DoubleValues(tensorProto);
- }
-
- throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
- }
-
- public static TensorType toVespaTensorType(TensorShapeProto shapeProto) {
- TensorType.Builder b = new TensorType.Builder();
- for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) {
- int dimensionSize = (int)dimension.getSize();
- if (dimensionSize >= 0)
- b.indexed("d" + b.rank(), dimensionSize);
- else
- b.indexed("d" + b.rank()); // unbound size
- }
- return b.build();
- }
-
- private static abstract class Values {
- protected final TensorProto tensorProto;
- protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; }
- abstract double get(int i);
- abstract int size();
- }
-
- private static class BoolValues extends Values {
- BoolValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
- @Override int size() { return tensorProto.getBoolValCount(); }
- }
-
- private static class HalfValues extends Values {
- HalfValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getHalfVal(i); }
- @Override int size() { return tensorProto.getHalfValCount(); }
- }
-
- private static class IntValues extends Values {
- IntValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getIntVal(i); }
- @Override int size() { return tensorProto.getIntValCount(); }
- }
-
- private static class Int64Values extends Values {
- Int64Values(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getInt64Val(i); }
- @Override int size() { return tensorProto.getInt64ValCount(); }
- }
-
- private static class FloatValues extends Values {
- FloatValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getFloatVal(i); }
- @Override int size() { return tensorProto.getFloatValCount(); }
- }
-
- private static class DoubleValues extends Values {
- DoubleValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getDoubleVal(i); }
- @Override int size() { return tensorProto.getDoubleValCount(); }
- }
-
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
deleted file mode 100644
index ef82045e771..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ /dev/null
@@ -1,715 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.Session;
-import org.tensorflow.framework.AttrValue;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-/**
- * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
- *
- * @author bratseth
- * @author lesters
- */
-class OperationMapper {
-
- // A note on conversion from implicitly numbered to explicitly named dimensions:
- //
- // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
- // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
- // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
- // around dimension renaming operations which mirrors those built into the TF operation definitions.
- //
- // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
- // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
- // and the result is then renamed again (if necessary) to recover this convention across a full nested
- // computation.
- //
- // This requires us to track tensor types throughout the conversion.
-
-
- // Supported TensorFlow operations
- enum Operation {
-
- // TODO: move the implementations to specific files as we support more operations
-
- /*
- * array ops
- */
- CONST (OperationMapper::constant),
- EXPANDDIMS (OperationMapper::expandDims),
- IDENTITY (OperationMapper::identity),
- PLACEHOLDER (OperationMapper::placeholder),
- PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault),
- RESHAPE (OperationMapper::reshape),
- SQUEEZE (OperationMapper::squeeze),
-
- /*
- * control flow
- */
- MERGE (OperationMapper::merge),
- SWITCH (OperationMapper::switchOp),
-
- /*
- * math ops
- */
- ADD (OperationMapper::add),
- ADD_N (OperationMapper::add),
- ACOS (OperationMapper::acos),
- DIV (OperationMapper::div),
- REALDIV (OperationMapper::div),
- FLOOR (OperationMapper::floor),
- MATMUL (OperationMapper::matmul),
- MAXIMUM (OperationMapper::maximum),
- MEAN (OperationMapper::mean),
- REDUCEMEAN (OperationMapper::mean),
- MUL (OperationMapper::mul),
- MULTIPLY (OperationMapper::mul),
- RSQRT (OperationMapper::rsqrt),
- SELECT (OperationMapper::select),
- WHERE3 (OperationMapper::select),
- SIGMOID (OperationMapper::sigmoid),
- SQUAREDDIFFERENCE (OperationMapper::squaredDifference),
- SUB (OperationMapper::sub),
- SUBTRACT (OperationMapper::sub),
-
- /*
- * nn ops
- */
- BIASADD (OperationMapper::add),
- ELU (OperationMapper::elu),
- RELU (OperationMapper::relu),
- SELU (OperationMapper::selu),
- SOFTMAX (OperationMapper::softMax),
-
- /*
- * state ops
- */
- VARIABLE (OperationMapper::variable),
- VARIABLEV2 (OperationMapper::variable),
-
- /*
- * evaluation no-ops
- */
- STOPGRADIENT (OperationMapper::identity),
- NOOP (OperationMapper::noOp);
-
-
- private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func;
-
- Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) {
- this.func = func;
- }
-
- Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
- return func.apply(params);
- }
-
- }
-
- static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
- Optional<Operation> operation = Stream.of(Operation.values())
- .filter(op -> op.name().equalsIgnoreCase(params.node().getOp()))
- .findFirst();
- if (operation.isPresent()) {
- return operation.get().map(params);
- }
- params.signature().importWarning("TensorFlow operation '" + params.node().getOp() +
- "' in node '" + params.node().getName() + "' is not supported.");
- return Optional.empty();
- }
-
-
- // Operations ---------------------------------
-
- private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
- Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
- if (value.type().rank() == 0) {
- TypedTensorFunction output = new TypedTensorFunction(value.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(
- new ConstantNode(new DoubleValue(value.asDouble()))));
- return Optional.of(output);
- }
- return createConstant(params, value);
- }
-
- private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- Tensor axis = getConstantTensor(params, params.node().getInput(1));
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
- }
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- int dimensionToInsert = (int)axis.asDouble();
- if (dimensionToInsert < 0) {
- dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
- }
-
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
- String name = String.format("temp_%d", i);
- Long size;
- if (i == dimensionToInsert) {
- size = 1L;
- } else {
- size = dimensionSize(inputType.dimensions().get(dimensionIndex));
- dimensionIndex++;
- }
- outputTypeBuilder.indexed(name, size);
- }
-
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
- }
-
- private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- return params.inputs().get(0);
- }
-
- private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) {
- String name = params.node().getName();
- String vespaName = toVespaName(params.node().getName());
- TensorType type = params.result().arguments().get(name);
- if (type == null) {
- throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
- "', but there is no such placeholder");
- }
- params.result().requiredMacro(vespaName, type);
- // Included literally in the expression and so must be produced by a separate macro in the rank profile
- TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
- String name = toVespaName(params.node().getInput(0));
- Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
- params.result().largeConstant(name, defaultValue);
- params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
- // The default value will be provided by the macro. Users can override macro to change value.
- TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) {
- if ( ! checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- Tensor shape = getConstantTensor(params, params.node().getInput(1));
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
- if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue();
- }
- outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
- dimensionIndex++;
- }
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
- }
-
- private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
- List<String> squeezeDimensions;
-
- AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
- squeezeDimensions = inputType.dimensions().stream().
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- } else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
- map(i -> i < 0 ? inputType.dimensions().size() - i : i).
- map(i -> inputType.dimensions().get(i.intValue())).
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- }
-
- if (squeezeDimensions.isEmpty()) {
- return inputs.get(0);
- }
-
- TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
- TensorType outputType = Reduce.outputType(inputType, squeezeDimensions);
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) {
- return params.inputs().stream()
- .filter(Optional::isPresent)
- .findFirst()
- .orElse(Optional.empty());
- }
-
- private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- Tensor predicate = getConstantTensor(params, params.node().getInput(1));
- if (predicate.type().rank() != 0) {
- throw new IllegalArgumentException("'switch': predicate must be a scalar");
- }
- double pred = predicate.asDouble();
- int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0;
- if (output < 0 || output > 1) {
- throw new IllegalArgumentException("'switch': predicate is not boolean");
- }
- if (pred == output) {
- return inputs.get(0);
- }
- return Optional.empty();
- }
-
- private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.add());
- }
-
- private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.acos());
- }
-
- private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.divide());
- }
-
- private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.floor());
- }
-
- private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) {
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
-
- TypedTensorFunction a = inputs.get(0).get();
- TypedTensorFunction b = inputs.get(1).get();
- if (a.type().rank() < 2 || b.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (a.type().rank() != b.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- String afterLastDim = "d" + (a.type().rank() + 1);
- // Let the first dimension of the second tensor be the same as the second dimension of the first
- // and the second dimension of the second argument be not present in the first argument, while leaving the
- // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
-
- // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
-
- Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
- ImmutableList.of("d1", afterLastDim));
- Matmul matmul = new Matmul(a.function(), renamedB, "d1");
- TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
- new Rename(matmul, afterLastDim, "d1"));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.max());
- }
-
- private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1));
- List<String> reduceDimensions = new ArrayList<>();
- for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int dimensionIndex = cell.getValue().intValue();
- if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - dimensionIndex;
- }
- reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
- }
-
- TensorType outputType = Reduce.outputType(inputType, reduceDimensions);
- TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
-
- if (shouldKeepDimensions(params)) {
- return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions));
- }
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.multiply());
- }
-
- private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.rsqrt());
- }
-
- private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 3)) {
- return Optional.empty();
- }
- Tensor condition = getConstantTensor(params, params.node().getInput(0));
-
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction x = inputs.get(1).get();
- TypedTensorFunction y = inputs.get(2).get();
- if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) {
- throw new IllegalArgumentException("'Select': input tensors must have the same shape");
- }
-
- if (condition.type().rank() == 0) {
- return Optional.of((int)condition.asDouble() == 0 ? y : x);
- }
- if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
- return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x);
- }
-
- // The task is to select cells from 'x' or 'y' based on 'condition'.
- // If 'condition' is 0 (false), select from 'y', if 1 (true) select
- // from 'x'. We do this by individually joining 'x' and 'y' with
- // 'condition', and then joining the resulting two tensors.
-
- Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0));
- if (!conditionFunction.isPresent()) {
- return Optional.empty();
- }
- TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply());
- TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() {
- @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
- @Override public String toString() { return "f(a,b)(a * (1-b))"; }
- });
- TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
- TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.sigmoid());
- }
-
- private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.squareddifference());
- }
-
- private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.subtract());
- }
-
- private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.elu());
- }
-
- private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.relu());
- }
-
- private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.selu());
- }
-
- private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction a = inputs.get(0).get();
- // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
- String dimension = "d" + (a.type().rank() - 1);
- Softmax softmax = new Softmax(a.function(), dimension);
- TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) {
- return importConstantTensor(params, params.node().getName());
- }
-
- private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) {
- return Optional.empty();
- }
-
- /*
- * Utility
- */
-
- private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- TypedTensorFunction a = inputs.get(0).get();
- TypedTensorFunction b = inputs.get(1).get();
-
- if (a.type().rank() == 0 && b.type().rank() > 0) {
- return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
- if (b.type().rank() == 0 && a.type().rank() > 0) {
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
- if (a.type().rank() == b.type().rank()) {
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
-
- // Well now we have entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // I'm not able to extract from that any unambiguous specification of which dimensions
- // should be "stretched" when the tensor do not have the same number of dimensions.
- // From trying this with TensorFlow it appears that the second tensor is matched to the
- // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
- // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
-
- if (a.type().rank() > b.type().rank()) {
- TensorFunction renameFunction = renameForBroadcast(a, b);
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction)));
- }
- TensorFunction renameFunction = renameForBroadcast(b, a);
- return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction)));
- }
-
- private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) {
- List<String> renameFrom = new ArrayList<>();
- List<String> renameTo = new ArrayList<>();
- int sizeDifference = a.type().rank() - b.type().rank();
- for (int i = 0; i < b.type().rank(); i++) {
- renameFrom.add(b.type().dimensions().get(i).name());
- renameTo.add("d" + (sizeDifference + i));
- }
- return new Rename(b.function(), renameFrom, renameTo);
- }
-
- private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction a = inputs.get(0).get();
- TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
- com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
- return Optional.of(new TypedTensorFunction(resultType, function));
- }
-
- private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
- String name = toVespaName(params.node().getName());
- if (constant.type().rank() == 0 || constant.size() <= 1) {
- params.result().smallConstant(name, constant);
- } else {
- params.result().largeConstant(name, constant);
- }
- TypedTensorFunction output = new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(
- new ReferenceNode("constant(\"" + name + "\")")));
- return Optional.of(output);
- }
-
- private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
- String vespaName = toVespaName(name);
- if (params.result().smallConstants().containsKey(vespaName)) {
- return params.result().smallConstants().get(vespaName);
- }
- if (params.result().largeConstants().containsKey(vespaName)) {
- return params.result().largeConstants().get(vespaName);
- }
- Session.Runner fetched = params.model().session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " +
- importedTensors.size());
- return TensorConverter.toVespaTensor(importedTensors.get(0));
- }
-
- private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) {
- AttrValue shapes = params.node().getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("'" + name + "' is missing a tensor shape");
- Tensor constant = getConstantTensor(params, name);
- return createConstant(params, constant);
- }
-
- private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
- throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
-
- // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
- // then use the dimension order of the new shape to roll back into a tensor.
- // Here we create a transformation tensor that is multiplied with the from tensor to map into
- // the new shape. We have to introduce temporary dimension names and rename back if dimension names
- // in the new and old tensor type overlap.
-
- ExpressionNode unrollFrom = unrollTensorExpression(inputType);
- ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
-
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
- Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
- return new ConstantNode(DoubleValue.zero);
- }
- List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
- int size = 1;
- for (int i = type.dimensions().size() - 1; i >= 0; --i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- children.add(0, new ReferenceNode(dimension.name()));
- if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
- children.add(0, new ConstantNode(new DoubleValue(size)));
- }
- size *= dimensionSize(dimension);
- if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
- }
- }
- return new ArithmeticNode(children, operators);
- }
-
- private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) {
- AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
- }
-
- private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension: inputType.dimensions()) {
- String name = dimension.name();
- Long size = dimensionSize(dimension);
- if (reduceDimensions.contains(name)) {
- size = 1L;
- }
- builder.indexed(name, size);
- }
- return builder.build();
- }
-
- private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) {
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- return fixNamingConvention(type, function);
- }
- }
- return new TypedTensorFunction(type, function);
- }
-
- private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) {
- TensorType.Builder correctType = new TensorType.Builder();
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- from.add(current);
- to.add(correct);
- }
- correctType.indexed(correct, dimensionSize(type.dimensions().get(i)));
- }
- if (from.size() > 0) {
- function = new Rename(function, from, to);
- type = correctType.build();
- }
- return new TypedTensorFunction(type, function);
- }
-
- private static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- private static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) {
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- if (!inputs.stream().allMatch(Optional::isPresent)) {
- return false;
- }
- if (inputs.size() != expected) {
- params.signature().importWarning("Expected " + expected +
- " arguments to " + params.node().getOp() + ", but got " + inputs.size());
- return false;
- }
- return true;
- }
-
- public static String toVespaName(String name) {
- return name != null ? name.replace('/', '_') : null;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
deleted file mode 100644
index b88ffce275a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
+++ /dev/null
@@ -1,155 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.nio.IntBuffer;
-import java.nio.LongBuffer;
-
-
-/**
- * Converts TensorFlow tensors into Vespa tensors.
- *
- * @author bratseth
- */
-public class TensorConverter {
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
- TensorType type = toVespaTensorType(tfTensor.shape());
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cellByDirectIndex(i, values.get(i));
- return builder.build();
- }
-
- private static TensorType toVespaTensorType(long[] shape) {
- TensorType.Builder b = new TensorType.Builder();
- int dimensionIndex = 0;
- for (long dimensionSize : shape) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed("d" + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
- private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
- switch (tfTensor.dataType()) {
- case DOUBLE: return new DoubleValues(tfTensor);
- case FLOAT: return new FloatValues(tfTensor);
- case BOOL: return new BoolValues(tfTensor);
- case UINT8: return new IntValues(tfTensor);
- case INT32: return new IntValues(tfTensor);
- case INT64: return new LongValues(tfTensor);
- default:
- throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
- }
- }
-
- /** Allows reading values from buffers of various numeric types as bytes */
- private static abstract class Values {
-
- private final int size;
-
- protected Values(int size) {
- this.size = size;
- }
-
- abstract double get(int i);
-
- int size() { return size; }
-
- }
-
- private static class DoubleValues extends Values {
-
- private final DoubleBuffer values;
-
- DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = DoubleBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class FloatValues extends Values {
-
- private final FloatBuffer values;
-
- FloatValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = FloatBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class BoolValues extends Values {
-
- private final ByteBuffer values;
-
- BoolValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = ByteBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class IntValues extends Values {
-
- private final IntBuffer values;
-
- IntValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = IntBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class LongValues extends Values {
-
- private final LongBuffer values;
-
- LongValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = LongBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index c97ee2b1514..7116d430502 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -2,10 +2,20 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
@@ -24,6 +34,7 @@ import java.util.stream.Collectors;
* Converts a saved TensorFlow model into a ranking expression and set of constants.
*
* @author bratseth
+ * @author lesters
*/
public class TensorFlowImporter {
@@ -57,196 +68,303 @@ public class TensorFlowImporter {
}
}
- private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
- TensorFlowModel result = new TensorFlowModel();
+ /**
+ * Imports the TensorFlow graph by first importing the tensor types, then
+ * finding a suitable set of dimensions names for each
+ * placeholder/constant/variable, then importing the expressions.
+ */
+ private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) {
+ TensorFlowModel model = new TensorFlowModel();
+ OperationIndex index = new OperationIndex();
+
+ importSignatures(graph, model);
+ importNodes(graph, model, index);
+ findDimensionNames(model, index);
+ importExpressions(model, index, bundle);
+
+ // nodes with multiple outputs are calculated multiple times. consider adding macros for those.
+
+ reportWarnings(model, index);
+
+ return model;
+ }
+
+ private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+ String signatureName = signatureEntry.getKey();
+ TensorFlowModel.Signature signature = model.signature(signatureName);
+
+ Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ signature.input(inputName, namePartOf(input.getValue().getName()));
+ }
- importInputs(signatureEntry.getValue().getInputsMap(), signature);
- for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
String outputName = output.getKey();
- try {
- NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef());
- Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, "");
-
- // Commonly, there are multiple paths through a TensorFlow graph, for instance for
- // training and testing/evaluation. Examples are dropout and batch norm. For Vespa
- // we are not concerned with training paths, so we can ignore non-supported operations
- // as long as they are on a path that will not be evaluated run time. Operations
- // that fail import will not have a value present in the optionals. However, the
- // final output node must have value present. It is an error if it does not.
-
- Optional<TypedTensorFunction> outputFunction = importNode(params);
- if (!outputFunction.isPresent()) {
- throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n")));
- }
- signature.output(outputName, namePartOf(output.getValue().getName()));
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
+ signature.output(outputName, namePartOf(output.getValue().getName()));
}
}
- return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
- inputInfoMap.forEach((key, value) -> {
- String argumentName = namePartOf(value.getName());
- TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape());
- // Arguments are (Placeholder) nodes, so not local to the signature:
- signature.owner().argument(argumentName, argumentType);
- signature.input(key, argumentName);
- });
+ private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
+ }
+ return false;
}
- /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private Optional<TypedTensorFunction> importNode(Parameters params) {
- String nodeName = params.node().getName();
- if (params.imported().containsKey(nodeName)) {
- return Optional.of(params.imported().get(nodeName));
+ private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
}
+ return false;
+ }
- Optional<TypedTensorFunction> function = OperationMapper.map(params);
- if ( ! function.isPresent()) {
- return Optional.empty();
- }
- if ( ! controlDependenciesArePresent(params)) {
- return Optional.empty();
+ private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ importNode(outputName, graph.getGraphDef(), index);
+ }
}
- params.imported().put(nodeName, function.get());
+ }
- try {
- // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
- // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
- params.result().expression(nodeName,
- new RankingExpression(nodeName, function.get().function().toString()));
- return function;
+ private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) {
+ if (index.alreadyImported(name)) {
+ return index.get(name);
}
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function.get().function() +
- " cannot be parsed as a ranking expression", e);
+ NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph);
+ List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index);
+ TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name));
+ index.put(name, operation);
+
+ List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
}
- }
- private boolean controlDependenciesArePresent(Parameters params) {
- return params.node().getInputList().stream()
- .filter(TensorFlowImporter::isControlDependency)
- .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
- .allMatch(Optional::isPresent);
+ return operation;
}
- private static boolean isControlDependency(String nodeName) {
- return nodeName.startsWith("^");
+ private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(name -> importNode(name, graph, index))
+ .collect(Collectors.toList());
}
- private List<Optional<TypedTensorFunction>> importArguments(Parameters params) {
- return params.node().getInputList().stream()
- .filter(nodeName -> !isControlDependency(nodeName))
- .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
+ private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(name -> isControlDependency(name))
+ .map(name -> importNode(name, graph, index))
.collect(Collectors.toList());
}
- private NodeDef getNode(String name, GraphDef graph) {
- return graph.getNodeList().stream()
- .filter(node -> node.getName().equals(name))
- .findFirst()
- .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
}
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- private static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
+ /** Find dimension names to avoid excessive renaming while evaluating the model. */
+ private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
}
- /**
- * This return the index part. Indexes are used for nodes with
- * multiple outputs.
- */
- private static String indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? "" : name.substring(i + 1);
+ private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
}
+ private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
- private Parameters createParameters(GraphDef graph,
- SavedModelBundle model,
- TensorFlowModel result,
- TensorFlowModel.Signature signature,
- NodeDef node,
- String port) {
- return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port);
+ private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
}
- /** Parameter object to hold important data while importing */
- static final class Parameters {
+ private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(model, operation, bundle);
+ }
- private final TensorFlowImporter owner;
- private final GraphDef graph;
- private final SavedModelBundle model;
- private final TensorFlowModel result;
- private final TensorFlowModel.Signature signature;
- private final Map<String, TypedTensorFunction> imported;
- private final NodeDef node;
- private final String port;
+ importInputExpressions(operation, model, bundle);
+ importRankingExpression(model, operation);
+ importInputExpression(model, operation);
+ importMacroExpression(model, operation);
- private Parameters(TensorFlowImporter owner,
- GraphDef graph,
- SavedModelBundle model,
- TensorFlowModel result,
- TensorFlowModel.Signature signature,
- Map<String, TypedTensorFunction> imported,
- NodeDef node,
- String port) {
- this.owner = owner;
- this.graph = graph;
- this.model = model;
- this.result = result;
- this.signature = signature;
- this.imported = imported;
- this.node = node;
- this.port = port;
- }
+ return operation.function();
+ }
- GraphDef graph() {
- return this.graph;
+ private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ operation.inputs().forEach(input -> importExpression(input, model, bundle));
+ }
+
+ private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.macro().isPresent()) {
+ model.macro(operation.vespaName(), operation.macro().get());
}
+ }
- SavedModelBundle model() {
- return this.model;
+ private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
}
- TensorFlowModel result() {
- return this.result;
+ Tensor tensor;
+ if (operation.getConstantValue().isPresent()) {
+ Value value = operation.getConstantValue().get();
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+ tensor = value.asTensor();
+ } else {
+ Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName());
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1) {
+ throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " +
+ importedTensors.size());
+ }
+ // Here we use the type from the operation, which will have correct dimension names after name resolving
+ tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get());
+ operation.setConstantValue(new TensorValue(tensor));
}
- TensorFlowModel.Signature signature() {
- return this.signature;
+ if (tensor.type().rank() == 0 || tensor.size() <= 1) {
+ model.smallConstant(operation.vespaName(), tensor);
+ } else {
+ model.largeConstant(operation.vespaName(), tensor);
}
+ return operation.function();
+ }
+
+ private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.function().isPresent()) {
+ String name = operation.node().getName();
+ if (!model.expressions().containsKey(operation.node().getName())) {
+ TensorFunction function = operation.function().get();
+
+ // Make sure output adheres to standard naming convention
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
- Map<String, TypedTensorFunction> imported() {
- return this.imported;
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced in a signature output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
}
+ }
- NodeDef node() {
- return node;
+ private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.isInput() && isSignatureInput(model, operation)) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
+ model.argument(operation.node().getName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
}
+ }
- String port() {
- return port;
+ private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ reportWarnings(index.get(output), signature);
+ }
}
+ }
- Parameters copy(NodeDef node, String port) {
- return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port);
+ private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
+ for (String warning : operation.warnings()) {
+ signature.importWarning(warning);
}
+ }
- List<Optional<TypedTensorFunction>> inputs() {
- return owner.importArguments(this);
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
+ for (NodeDef node : graph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
}
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This return the output port part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ private static int portPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+
+ private static class OperationIndex {
+ private final Map<String, TensorFlowOperation> index = new HashMap<>();
+ public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
+ public TensorFlowOperation get(String key) { return index.get(key); }
+ public boolean alreadyImported(String key) { return index.containsKey(key); }
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
deleted file mode 100644
index 600225bfe76..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-/**
- * A tensor function returning a specific tensor type
- *
- * @author bratseth
- */
-final class TypedTensorFunction {
-
- private final TensorType type;
- private final TensorFunction function;
-
- public TypedTensorFunction(TensorType type, TensorFunction function) {
- this.type = type;
- this.function = function;
- }
-
- public TensorType type() {
- return type;
- }
-
- public TensorFunction function() {
- return function;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
new file mode 100644
index 00000000000..c1665d066a4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
@@ -0,0 +1,210 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A constraint satisfier to find suitable dimension names to reduce the
+ * amount of necessary renaming during evaluation of an imported model.
+ *
+ * @author lesters
+ */
+public class DimensionRenamer {
+
+ private final String dimensionPrefix;
+ private final Map<String, List<Integer>> variables = new HashMap<>();
+ private final Map<Arc, Constraint> constraints = new HashMap<>();
+ private final Map<String, Integer> renames = new HashMap<>();
+
+ private int iterations = 0;
+
+ public DimensionRenamer() {
+ this("d");
+ }
+
+ public DimensionRenamer(String dimensionPrefix) {
+ this.dimensionPrefix = dimensionPrefix;
+ }
+
+ /**
+ * Add a dimension name variable.
+ */
+ public void addDimension(String name) {
+ variables.computeIfAbsent(name, d -> new ArrayList<>());
+ }
+
+ /**
+ * Add a constraint between dimension names.
+ */
+ public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
+ Arc arc = new Arc(from, to, operation);
+ Arc opposite = arc.opposite();
+ constraints.put(arc, pred);
+ constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ }
+
+ /**
+ * Retrieve resulting name of dimension after solving for constraints.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if (!renames.containsKey(name)) {
+ return Optional.empty();
+ }
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ }
+
+ /**
+ * Perform iterative arc consistency until we have found a solution. After
+ * an initial iteration, the variables (dimensions) will have multiple
+ * valid values. Find a single valid assignment by iteratively locking one
+ * dimension after another, and running the arc consistency algorithm
+ * multiple times.
+ *
+ * This requires having constraints that result in an absolute ordering:
+ * equals, lesserThan and greaterThan do that, but adding notEquals does
+ * not typically result in a guaranteed ordering. If that is needed, the
+ * algorithm below needs to be adapted with a backtracking (tree) search
+ * to find solutions.
+ */
+ public void solve(int maxIterations) {
+ initialize();
+
+ // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+
+ for (String dimension : variables.keySet()) {
+ List<Integer> values = variables.get(dimension);
+ if (values.size() > 1) {
+ if (!ac3()) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+ }
+ values.sort(Integer::compare);
+ variables.put(dimension, Collections.singletonList(values.get(0)));
+ }
+ renames.put(dimension, variables.get(dimension).get(0));
+ if (iterations > maxIterations) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
+ maxIterations + " iterations");
+ }
+ }
+
+ // Todo: handle failure more gracefully:
+ // If a solution can't be found, look at the operation node in the arc
+ // with the most remaining constraints, and inject a rename operation.
+ // Then run this algorithm again.
+ }
+
+ public void solve() {
+ solve(100000);
+ }
+
+ private void initialize() {
+ for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+ List<Integer> values = variable.getValue();
+ for (int i = 0; i < variables.size(); ++i) {
+ values.add(i); // invariant: values are in increasing order
+ }
+ }
+ }
+
+ private boolean ac3() {
+ Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while (!workList.isEmpty()) {
+ Arc arc = workList.pop();
+ iterations += 1;
+ if (revise(arc)) {
+ if (variables.get(arc.from).size() == 0) {
+ return false; // no solution found
+ }
+ for (Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+ workList.add(constraint);
+ }
+ }
+ }
+ }
+ return true;
+ }
+
+ private boolean revise(Arc arc) {
+ boolean revised = false;
+ for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).test(from, to)) {
+ satisfied = true;
+ }
+ }
+ if (!satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ public interface Constraint {
+ boolean test(Integer x, Integer y);
+ }
+
+ public static boolean equals(Integer x, Integer y) {
+ return Objects.equals(x, y);
+ }
+
+ public static boolean lesserThan(Integer x, Integer y) {
+ return x < y;
+ }
+
+ public static boolean greaterThan(Integer x, Integer y) {
+ return x > y;
+ }
+
+ private static class Arc {
+
+ private final String from;
+ private final String to;
+ private final TensorFlowOperation operation;
+
+ Arc(String from, String to, TensorFlowOperation operation) {
+ this.from = from;
+ this.to = to;
+ this.operation = operation;
+ }
+
+ Arc opposite() {
+ return new Arc(to, from, operation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(from, to);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof Arc)) {
+ return false;
+ }
+ Arc other = (Arc) obj;
+ return Objects.equals(from, other.from) && Objects.equals(to, other.to);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("%s -> %s", from, to);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
new file mode 100644
index 00000000000..0fe73fad8ce
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -0,0 +1,108 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+/**
+ * Maps from TensorFlow operations to Vespa operations.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OperationMapper {
+
+ public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ switch (node.getOp().toLowerCase()) {
+ /*
+ * array ops
+ */
+ case "const": return new Const(node, inputs, port);
+ case "expanddims": return new ExpandDims(node, inputs, port);
+ case "identity": return new Identity(node, inputs, port);
+ case "placeholder": return new Placeholder(node, inputs, port);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port);
+ case "reshape": return new Reshape(node, inputs, port);
+ case "shape": return new Shape(node, inputs, port);
+ case "squeeze": return new Squeeze(node, inputs, port);
+
+ /*
+ * control flow
+ */
+ case "merge": return new Merge(node, inputs, port);
+ case "switch": return new Switch(node, inputs, port);
+
+ /*
+ * math ops
+ */
+ case "add": return new Join(node, inputs, port, ScalarFunctions.add());
+ case "add_n": return new Join(node, inputs, port, ScalarFunctions.add());
+ case "acos": return new Map(node, inputs, port, ScalarFunctions.acos());
+ case "div": return new Join(node, inputs, port, ScalarFunctions.divide());
+ case "realdiv": return new Join(node, inputs, port, ScalarFunctions.divide());
+ case "floor": return new Map(node, inputs, port, ScalarFunctions.floor());
+ case "matmul": return new Matmul(node, inputs, port);
+ case "maximum": return new Join(node, inputs, port, ScalarFunctions.max());
+ case "mean": return new Mean(node, inputs, port);
+ case "reducemean": return new Mean(node, inputs, port);
+ case "mul": return new Join(node, inputs, port, ScalarFunctions.multiply());
+ case "multiply": return new Join(node, inputs, port, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(node, inputs, port, ScalarFunctions.rsqrt());
+ case "select": return new Select(node, inputs, port);
+ case "where3": return new Select(node, inputs, port);
+ case "sigmoid": return new Map(node, inputs, port, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(node, inputs, port, ScalarFunctions.squareddifference());
+ case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract());
+ case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract());
+
+ /*
+ * nn ops
+ */
+ case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add());
+ case "elu": return new Map(node, inputs, port, ScalarFunctions.elu());
+ case "relu": return new Map(node, inputs, port, ScalarFunctions.relu());
+ case "selu": return new Map(node, inputs, port, ScalarFunctions.selu());
+
+ /*
+ * random ops
+ */
+
+ /*
+ * state ops
+ */
+ case "variable": return new Variable(node, inputs, port);
+ case "variablev2": return new Variable(node, inputs, port);
+
+ /*
+ * evaluation no-ops
+ */
+ case "stopgradient":return new Identity(node, inputs, port);
+ case "noop": return new NoOp(node, inputs, port);
+ }
+ return new NoOp(node, inputs, port);
+ }
+
+}
+
+
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
new file mode 100644
index 00000000000..3742e443a06
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -0,0 +1,237 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A Vespa tensor type is ordered by the lexicographical ordering of dimension
+ * names. TensorFlow tensors have an explicit ordering of their dimensions.
+ * During import, we need to track the Vespa dimension that matches the
+ * corresponding TensorFlow dimension as the ordering can change after
+ * dimension renaming. That is the purpose of this class.
+ *
+ * @author lesters
+ */
+public class OrderedTensorType {
+
+ private final TensorType type;
+ private final List<TensorType.Dimension> dimensions;
+
+ private final long[] innerSizesTensorFlow;
+ private final long[] innerSizesVespa;
+ private final int[] dimensionMap;
+
+ private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ this.dimensions = Collections.unmodifiableList(dimensions);
+ this.type = new TensorType.Builder(dimensions).build();
+ this.innerSizesTensorFlow = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ public TensorType type() {
+ return this.type;
+ }
+
+ public List<TensorType.Dimension> dimensions() {
+ return dimensions;
+ }
+
+ public List<String> dimensionNames() {
+ return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
+ private int[] createDimensionMap() {
+ int numDimensions = dimensions.size();
+ if (numDimensions == 0) {
+ return null;
+ }
+ innerSizesTensorFlow[numDimensions - 1] = 1;
+ innerSizesVespa[numDimensions - 1] = 1;
+ for (int i = numDimensions - 1; --i >= 0; ) {
+ innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
+ innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
+ }
+ int[] mapping = new int[numDimensions];
+ for (int i = 0; i < numDimensions; ++i) {
+ TensorType.Dimension dim1 = dimensions().get(i);
+ for (int j = 0; j < numDimensions; ++j) {
+ TensorType.Dimension dim2 = type.dimensions().get(j);
+ if (dim1.equals(dim2)) {
+ mapping[i] = j;
+ break;
+ }
+ }
+ }
+ return mapping;
+ }
+
+ /**
+ * When dimension ordering between Vespa and TensorFlow differs, i.e.
+ * after dimension renaming, use the dimension map to read in values
+ * so that they are correctly laid out in memory for Vespa.
+ * Used when importing tensors from TensorFlow.
+ */
+ public int toDirectIndex(int index) {
+ if (dimensions.size() == 0) {
+ return 0;
+ }
+ if (dimensionMap == null) {
+ throw new IllegalArgumentException("Dimension map is not available");
+ }
+ int directIndex = 0;
+ long rest = index;
+ for (int i = 0; i < dimensions.size(); ++i) {
+ long address = rest / innerSizesTensorFlow[i];
+ directIndex += innerSizesVespa[dimensionMap[i]] * address;
+ rest %= innerSizesTensorFlow[i];
+ }
+ return directIndex;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof OrderedTensorType)) {
+ return false;
+ }
+ OrderedTensorType other = (OrderedTensorType) obj;
+ if (dimensions.size() != dimensions.size()) {
+ return false;
+ }
+ List<TensorType.Dimension> thisDimensions = this.dimensions();
+ List<TensorType.Dimension> otherDimensions = other.dimensions();
+ for (int i = 0; i < thisDimensions.size(); ++i) {
+ if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public static void verifyType(NodeDef node, OrderedTensorType type) {
+ if (type == null) {
+ return;
+ }
+ TensorShapeProto shape = tensorFlowShape(node);
+ if (shape != null && type.type != null) {
+ if (shape.getDimCount() != type.type.rank()) {
+ throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+ "does not match Vespa shape");
+ }
+ for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) {
+ int vespaIndex = type.dimensionMap[tensorFlowIndex];
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+ "does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ }
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+ }
+ List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+ return shapeList.get(0); // support multiple outputs?
+ }
+
+ public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size());
+ for (TensorType.Dimension dimension : type.dimensions) {
+ String oldName = dimension.name();
+ Optional<String> newName = renamer.dimensionNameOf(oldName);
+ if (!newName.isPresent())
+ return type; // presumably, already renamed
+ TensorType.Dimension.Type dimensionType = dimension.type();
+ if (dimensionType == TensorType.Dimension.Type.indexedBound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
+ } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
+ } else if (dimensionType == TensorType.Dimension.Type.mapped) {
+ renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
+ }
+ }
+ return new OrderedTensorType(renamedDimensions);
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ Builder builder = new Builder(node);
+ TensorShapeProto shape = tensorFlowShape(node);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static class Builder {
+
+ private final TensorShapeProto shape;
+ private final List<TensorType.Dimension> dimensions;
+
+ public Builder(NodeDef node) {
+ this.shape = tensorFlowShape(node);
+ this.dimensions = new ArrayList<>(shape.getDimCount());
+ }
+
+ public Builder add(TensorType.Dimension vespaDimension) {
+ int index = dimensions.size();
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
+ long size = tensorFlowDimension.getSize();
+ if (size >= 0) {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension types");
+ }
+ if (!vespaDimension.size().isPresent()) {
+ throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
+ "not have a size");
+ }
+ if (vespaDimension.size().get() != size) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get());
+ }
+ } else {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension types");
+ }
+ }
+ this.dimensions.add(vespaDimension);
+ return this;
+ }
+
+ public OrderedTensorType build() {
+ return new OrderedTensorType(dimensions);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
new file mode 100644
index 00000000000..3f55e622fdf
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
@@ -0,0 +1,224 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.TensorProto;
+
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+
+
+/**
+ * Converts TensorFlow tensors into Vespa tensors.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorConverter {
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ return toVespaTensor(tfTensor, "d");
+ }
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
+ for (int i = 0; i < values.size(); i++) {
+ builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
+ }
+ return builder.build();
+ }
+
+ public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ Values values = readValuesOf(tensorProto);
+ for (int i = 0; i < values.size(); ++i) {
+ builder.cellByDirectIndex(i, values.get(i));
+ }
+ return builder.build();
+ }
+
+ private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
+ }
+ return b.build();
+ }
+
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ case BOOL: return new BoolValues(tfTensor);
+ case UINT8: return new IntValues(tfTensor);
+ case INT32: return new IntValues(tfTensor);
+ case INT64: return new LongValues(tfTensor);
+ }
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+
+ private static Values readValuesOf(TensorProto tensorProto) {
+ switch (tensorProto.getDtype()) {
+ case DT_BOOL:
+ return new ProtoBoolValues(tensorProto);
+ case DT_HALF:
+ return new ProtoHalfValues(tensorProto);
+ case DT_INT16:
+ case DT_INT32:
+ return new ProtoIntValues(tensorProto);
+ case DT_INT64:
+ return new ProtoInt64Values(tensorProto);
+ case DT_FLOAT:
+ return new ProtoFloatValues(tensorProto);
+ case DT_DOUBLE:
+ return new ProtoDoubleValues(tensorProto);
+ }
+ throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+ abstract double get(int i);
+ abstract int size();
+ }
+
+ private static abstract class TensorFlowValues extends Values {
+ private final int size;
+ TensorFlowValues(int size) {
+ this.size = size;
+ }
+ @Override int size() { return this.size; }
+ }
+
+ private static class DoubleValues extends TensorFlowValues {
+ private final DoubleBuffer values;
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class FloatValues extends TensorFlowValues {
+ private final FloatBuffer values;
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class BoolValues extends TensorFlowValues {
+ private final ByteBuffer values;
+ BoolValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = ByteBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class IntValues extends TensorFlowValues {
+ private final IntBuffer values;
+ IntValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = IntBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class LongValues extends TensorFlowValues {
+ private final LongBuffer values;
+ LongValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = LongBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static abstract class ProtoValues extends Values {
+ protected final TensorProto tensorProto;
+ protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
+ }
+
+ private static class ProtoBoolValues extends ProtoValues {
+ ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
+ @Override int size() { return tensorProto.getBoolValCount(); }
+ }
+
+ private static class ProtoHalfValues extends ProtoValues {
+ ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getHalfVal(i); }
+ @Override int size() { return tensorProto.getHalfValCount(); }
+ }
+
+ private static class ProtoIntValues extends ProtoValues {
+ ProtoIntValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getIntVal(i); }
+ @Override int size() { return tensorProto.getIntValCount(); }
+ }
+
+ private static class ProtoInt64Values extends ProtoValues {
+ ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getInt64Val(i); }
+ @Override int size() { return tensorProto.getInt64ValCount(); }
+ }
+
+ private static class ProtoFloatValues extends ProtoValues {
+ ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getFloatVal(i); }
+ @Override int size() { return tensorProto.getFloatValCount(); }
+ }
+
+ private static class ProtoDoubleValues extends ProtoValues {
+ ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getDoubleVal(i); }
+ @Override int size() { return tensorProto.getDoubleValCount(); }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
new file mode 100644
index 00000000000..7decef51ab7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -0,0 +1,93 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Const extends TensorFlowOperation {
+
+ public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ setConstantValue(value());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ }
+
+ @Override
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ function = lazyGetFunction();
+ }
+ return Optional.ofNullable(function);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ ExpressionNode expressionNode;
+ if (type.type().rank() == 0 && getConstantValue().isPresent()) {
+ expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
+ } else {
+ expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ }
+ return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ setConstantValue(value());
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private Value value() {
+ if (!node.getAttrMap().containsKey("value")) {
+ throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
+ "const has missing 'value' attribute");
+ }
+ AttrValue attrValue = node.getAttrMap().get("value");
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return new BooleanValue(attrValue.getB());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return new DoubleValue(attrValue.getI());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return new DoubleValue(attrValue.getF());
+ }
+ throw new IllegalArgumentException("Requesting value of constant in " +
+ node.getName() + " but type is not recognized.");
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
new file mode 100644
index 00000000000..c1ad21f41d8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+public class ExpandDims extends TensorFlowOperation {
+
+ private List<String> expandDimensions;
+
+ public ExpandDims(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+
+ TensorFlowOperation axisOperation = inputs().get(1);
+ if (!axisOperation.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ "axis must be a constant.");
+ }
+ Tensor axis = axisOperation.getConstantValue().get().asTensor();
+ if (axis.type().rank() != 0) {
+ throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ "axis argument must be a scalar.");
+ }
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ int dimensionToInsert = (int)axis.asDouble();
+ if (dimensionToInsert < 0) {
+ dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ expandDimensions = new ArrayList<>();
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
+ if (dimensionIndex == dimensionToInsert) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+ typeBuilder.add(dimension);
+ dimensionIndex++;
+ }
+
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (String name : expandDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
+ for (String name : expandDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ expandDimensions = renamedDimensions;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
new file mode 100644
index 00000000000..d79707a42e6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
@@ -0,0 +1,30 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Identity extends TensorFlowOperation {
+
+ public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ return inputs.get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1))
+ return null;
+ return inputs.get(0).function().orElse(null);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
new file mode 100644
index 00000000000..aa27ba2684d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
@@ -0,0 +1,79 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
+
+public class Join extends TensorFlowOperation {
+
+ private final DoubleBinaryOperator operator;
+
+ public Join(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
+ super(node, inputs, port);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ OrderedTensorType out = a.type().rank() >= b.type().rank() ? a : b;
+ return out;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (!aFunction.isPresent() || !bFunction.isPresent()) {
+ return null;
+ }
+
+ // The dimension renaming below takes care of broadcasting.
+
+ return new com.yahoo.tensor.functions.Join(aFunction.get(), bFunction.get(), operator);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+
+ // Well now we have potentially entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensor do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+
+ TensorType a = inputs.get(0).type().get().type();
+ TensorType b = inputs.get(1).type().get().type();
+ if (a.rank() < b.rank()) {
+ TensorType temp = a;
+ a = b;
+ b = temp;
+ }
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < b.rank(); ++i) {
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
new file mode 100644
index 00000000000..105d65b3d69
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
@@ -0,0 +1,38 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleUnaryOperator;
+
+public class Map extends TensorFlowOperation {
+
+ private final DoubleUnaryOperator operator;
+
+ public Map(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
+ super(node, inputs, port);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ return inputs.get(0).type().get();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ Optional<TensorFunction> input = inputs.get(0).function();
+ return new com.yahoo.tensor.functions.Map(input.get(), operator);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
new file mode 100644
index 00000000000..ac4f78653d6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
@@ -0,0 +1,74 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Matmul extends TensorFlowOperation {
+
+ public Matmul(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
+ typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+ if (aType.type().rank() < 2 || bType.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (aType.type().rank() != bType.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (!aFunction.isPresent() || !bFunction.isPresent()) {
+ return null;
+ }
+ return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // The second dimension of a should have the same name as the first dimension of b
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+
+ // The first dimension of a should have a different name than the second dimension of b
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
new file mode 100644
index 00000000000..dfe0796d9b8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
@@ -0,0 +1,112 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+public class Mean extends TensorFlowOperation {
+
+ private List<String> reduceDimensions;
+
+ public Mean(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ TensorFlowOperation reductionIndices = inputs.get(1);
+ if (!reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+ "reduction indices must be a constant.");
+ }
+ Tensor indices = reductionIndices.getConstantValue().get().asTensor();
+ reduceDimensions = new ArrayList<>();
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int dimensionIndex = cell.getValue().intValue();
+ if (dimensionIndex < 0) {
+ dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ }
+ reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+ }
+ return reducedType(inputType, shouldKeepDimensions());
+ }
+
+ // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
+ if (shouldKeepDimensions()) {
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ }
+ return output;
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+ for (String name : reduceDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ reduceDimensions = renamedDimensions;
+ }
+
+ private boolean shouldKeepDimensions() {
+ AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
+ return keepDimsAttr != null && keepDimsAttr.getB();
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if (!reduceDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ } else if (keepDimensions) {
+ builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
new file mode 100644
index 00000000000..d3561716725
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
@@ -0,0 +1,36 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Merge extends TensorFlowOperation {
+
+ public Merge(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ for (TensorFlowOperation operation : inputs) {
+ if (operation.type().isPresent()) {
+ return operation.type().get();
+ }
+ }
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ for (TensorFlowOperation operation : inputs) {
+ if (operation.function().isPresent()) {
+ return operation.function().get();
+ }
+ }
+ return null;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
new file mode 100644
index 00000000000..acf5d13b057
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends TensorFlowOperation {
+
+ public NoOp(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, Collections.emptyList(), port); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
new file mode 100644
index 00000000000..dadce395faf
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
@@ -0,0 +1,57 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Placeholder extends TensorFlowOperation {
+
+ private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
+
+ public Placeholder(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ standardNamingType = OrderedTensorType.fromTensorFlowType(node);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
+ if (!standardNamingType.equals(type)) {
+ List<String> renameFrom = standardNamingType.dimensionNames();
+ List<String> renameTo = type.dimensionNames();
+ output = new Rename(output, renameFrom, renameTo);
+ }
+ return output;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isInput() {
+ return true;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return false;
+ }
+
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
new file mode 100644
index 00000000000..ab091b77a65
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
@@ -0,0 +1,50 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+public class PlaceholderWithDefault extends TensorFlowOperation {
+
+ public PlaceholderWithDefault(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ return inputs().get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ // This should be a call to the macro we add below, but for now
+ // we treat this as as identity function and just pass the constant.
+ return inputs.get(0).function().orElse(null);
+ }
+
+ @Override
+ public Optional<RankingExpression> macro() {
+ // For now, it is much more efficient to assume we always will return
+ // the default value, as we can prune away large parts of the expression
+ // tree by having it calculated as a constant. If a case arises where
+ // it is important to support this, implement this.
+ return Optional.empty();
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true; // not true if we add to macro
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
new file mode 100644
index 00000000000..9b3e28ce56b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
@@ -0,0 +1,135 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+
+public class Reshape extends TensorFlowOperation {
+
+ public Reshape(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ TensorFlowOperation newShape = inputs.get(1);
+ if (!newShape.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+ "shape input must be a constant.");
+ }
+ Tensor shape = newShape.getConstantValue().get().asTensor();
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
+ int dimensionIndex = 0;
+ for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int size = cell.getValue().intValue();
+ if (size < 0) {
+ size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
+ tensorSize(inputType.type()).intValue();
+ }
+ outputTypeBuilder.add(TensorType.Dimension.indexed(
+ String.format("%s_%d", vespaName(), dimensionIndex), size));
+ dimensionIndex++;
+ }
+ return outputTypeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return reshape(inputFunction, inputType.type(), type.type());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
+ if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+ throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ }
+
+ // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
+ // then use the dimension order of the new shape to roll back into a tensor.
+ // Here we create a transformation tensor that is multiplied with the from tensor to map into
+ // the new shape. We have to introduce temporary dimension names and rename back if dimension names
+ // in the new and old tensor type overlap.
+
+ ExpressionNode unrollFrom = unrollTensorExpression(inputType);
+ ExpressionNode unrollTo = unrollTensorExpression(outputType);
+ ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
+
+ TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+ Generate transformTensor = new Generate(transformationType,
+ new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
+
+ TensorFunction outputFunction = new Reduce(
+ new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+
+ return outputFunction;
+ }
+
+ private static ExpressionNode unrollTensorExpression(TensorType type) {
+ if (type.rank() == 0) {
+ return new ConstantNode(DoubleValue.zero);
+ }
+ List<ExpressionNode> children = new ArrayList<>();
+ List<ArithmeticOperator> operators = new ArrayList<>();
+ int size = 1;
+ for (int i = type.dimensions().size() - 1; i >= 0; --i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ children.add(0, new ReferenceNode(dimension.name()));
+ if (size > 1) {
+ operators.add(0, ArithmeticOperator.MULTIPLY);
+ children.add(0, new ConstantNode(new DoubleValue(size)));
+ }
+ size *= TensorConverter.dimensionSize(dimension);
+ if (i > 0) {
+ operators.add(0, ArithmeticOperator.PLUS);
+ }
+ }
+ return new ArithmeticNode(children, operators);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
new file mode 100644
index 00000000000..6a29d428cf3
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
@@ -0,0 +1,89 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+
+public class Select extends TensorFlowOperation {
+
+ public Select(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(3)) {
+ return null;
+ }
+ OrderedTensorType a = inputs.get(1).type().get();
+ OrderedTensorType b = inputs.get(2).type().get();
+ if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) {
+ throw new IllegalArgumentException("'Select': input tensors must have the same shape");
+ }
+ return a;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(3)) {
+ return null;
+ }
+ TensorFlowOperation conditionOperation = inputs().get(0);
+ TensorFunction a = inputs().get(1).function().get();
+ TensorFunction b = inputs().get(2).function().get();
+
+ // Shortcut: if we know during import which tensor to select, do that directly here.
+ if (conditionOperation.getConstantValue().isPresent()) {
+ Tensor condition = conditionOperation.getConstantValue().get().asTensor();
+ if (condition.type().rank() == 0) {
+ return ((int) condition.asDouble() == 0) ? b : a;
+ }
+ if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
+ return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
+ }
+ }
+
+ // The task is to select cells from 'x' or 'y' based on 'condition'.
+ // If 'condition' is 0 (false), select from 'y', if 1 (true) select
+ // from 'x'. We do this by individually joining 'x' and 'y' with
+ // 'condition', and then joining the resulting two tensors.
+
+ TensorFunction conditionFunction = conditionOperation.function().get();
+ TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
+ TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {
+ @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
+ @Override public String toString() { return "f(a,b)(a * (1-b))"; }
+ });
+ return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(3)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // These tensors should have the same dimension names
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
new file mode 100644
index 00000000000..8f4313022e0
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
@@ -0,0 +1,55 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Shape extends TensorFlowOperation {
+
+ public Shape(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ createConstantValue();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ return new OrderedTensorType.Builder(node)
+ .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
+ .build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private void createConstantValue() {
+ if (!allInputTypesPresent(1)) {
+ return;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
+ List<TensorType.Dimension> inputDimensions = inputType.dimensions();
+ for (int i = 0; i < inputDimensions.size(); i++) {
+ builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L));
+ }
+ this.setConstantValue(new TensorValue(builder.build()));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
new file mode 100644
index 00000000000..d7750b52fc3
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
@@ -0,0 +1,84 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class Squeeze extends TensorFlowOperation {
+
+ private List<String> squeezeDimensions;
+
+ public Squeeze(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ squeezeDimensions = new ArrayList<>();
+
+ AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
+ if (squeezeDimsAttr == null) {
+ squeezeDimensions = inputType.type().dimensions().stream().
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ } else {
+ squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
+ map(i -> inputType.type().dimensions().get(i.intValue())).
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ }
+ return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
+ for (String name : squeezeDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ squeezeDimensions = renamedDimensions;
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! squeezeDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
new file mode 100644
index 00000000000..1cc0e1936eb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
@@ -0,0 +1,48 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Switch extends TensorFlowOperation {
+
+ public Switch(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ Optional<OrderedTensorType> predicate = inputs.get(1).type();
+ if (predicate.get().type().rank() != 0) {
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ "predicate must be a scalar");
+ }
+ return inputs.get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ TensorFlowOperation predicateOperation = inputs().get(1);
+ if (!predicateOperation.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ "predicate must be a constant");
+ }
+ if (port < 0 || port > 1) {
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ "choice should be boolean");
+ }
+
+ double predicate = predicateOperation.getConstantValue().get().asDouble();
+ return predicate == port ? inputs().get(0).function().get() : null;
+ }
+
+}
+
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
new file mode 100644
index 00000000000..fd9dfd167fb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -0,0 +1,136 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
+ * During import, a graph of these operations are constructed. Then, the
+ * types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper
+ * Vespa expressions can be extracted.
+ *
+ * @author lesters
+ */
+public abstract class TensorFlowOperation {
+
+ protected final NodeDef node;
+ protected final int port;
+ protected final List<TensorFlowOperation> inputs;
+ protected final List<TensorFlowOperation> outputs = new ArrayList<>();
+ protected final List<String> importWarnings = new ArrayList<>();
+
+ protected OrderedTensorType type;
+ protected TensorFunction function;
+
+ private Value constantValue = null;
+ private List<TensorFlowOperation> controlInputs = Collections.emptyList();
+
+ TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ this.node = node;
+ this.port = port;
+ this.inputs = Collections.unmodifiableList(inputs);
+ this.inputs.forEach(i -> i.outputs.add(this));
+ }
+
+ protected abstract OrderedTensorType lazyGetType();
+ protected abstract TensorFunction lazyGetFunction();
+
+ /** Returns the Vespa tensor type of this operation if it exists */
+ public Optional<OrderedTensorType> type() {
+ if (type == null) {
+ type = lazyGetType();
+ }
+ OrderedTensorType.verifyType(node, type);
+ return Optional.ofNullable(type);
+ }
+
+ /** Returns the Vespa tensor function implementing all operations from this node with inputs */
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ if (isConstant()) {
+ ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else {
+ function = lazyGetFunction();
+ }
+ }
+ return Optional.ofNullable(function);
+ }
+
+ /** Return TensorFlow node */
+ public NodeDef node() { return node; }
+
+ /** Return unmodifiable list of inputs */
+ public List<TensorFlowOperation> inputs() { return inputs; }
+
+ /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
+ public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
+
+ /** Returns a Vespa ranking expression that should be added as a macro */
+ public Optional<RankingExpression> macro() { return Optional.empty(); }
+
+ /** Add dimension name constraints for this operation */
+ public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+
+ /** Performs dimension rename for this operation */
+ public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); }
+
+ /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
+ public boolean isInput() { return false; }
+
+ /** Return true if this node is constant */
+ public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
+
+ /** Sets the constant value */
+ public void setConstantValue(Value value) { constantValue = value; }
+
+ /** Gets the constant value if it exists */
+ public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+
+ /** Sets the external control inputs */
+ public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
+
+ /** Retrieve the control inputs for this operation */
+ public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+
+ /** Retrieve the valid Vespa name of this node */
+ public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+
+ /** Retrieve the list of warnings produced during its lifetime */
+ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+
+ boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
+ if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
+ return false;
+ }
+ if (inputs.size() != expected) {
+ throw new IllegalArgumentException("Expected " + expected + " inputs " +
+ "for '" + node.getName() + "', got " + inputs.size());
+ }
+ return inputs.stream().map(func).allMatch(Optional::isPresent);
+ }
+
+ boolean allInputTypesPresent(int expected) {
+ return verifyInputs(expected, TensorFlowOperation::type);
+ }
+
+ boolean allInputFunctionsPresent(int expected) {
+ return verifyInputs(expected, TensorFlowOperation::function);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
new file mode 100644
index 00000000000..6f377c4bda2
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
@@ -0,0 +1,40 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Variable extends TensorFlowOperation {
+
+ public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
index fb9a7cb9ad7..d3a12d0f312 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
@@ -13,7 +13,7 @@ import java.util.List;
/**
* A set of argument expressions to a function or feature.
- * This is immutable.
+ * This is a value object.
*
* @author bratseth
*/
@@ -22,7 +22,11 @@ public final class Arguments implements Serializable {
private final ImmutableList<ExpressionNode> expressions;
public Arguments() {
- this(null);
+ this(ImmutableList.of());
+ }
+
+ public Arguments(ExpressionNode singleArgument) {
+ this(ImmutableList.of(singleArgument));
}
public Arguments(List<? extends ExpressionNode> expressions) {
@@ -38,9 +42,12 @@ public final class Arguments implements Serializable {
this.expressions = b.build();
}
- /** Returns an unmodifiable list of the expressions in this */
+ /** Returns an unmodifiable list of the expressions in this, never null */
public List<ExpressionNode> expressions() { return expressions; }
+ /** Returns the number of arguments in this */
+ public int size() { return expressions.size(); }
+
/** Evaluate all arguments in this */
public Value[] evaluate(Context context) {
Value[] values=new Value[expressions.size()];
@@ -62,8 +69,9 @@ public final class Arguments implements Serializable {
}
@Override
- public boolean equals(Object rhs) {
- return rhs instanceof Arguments && expressions.equals(((Arguments)rhs).expressions);
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ return other instanceof Arguments && expressions.equals(((Arguments)other).expressions);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index fc6428a4c33..49c49bed9bd 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -80,7 +81,7 @@ public final class ArithmeticNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
// Compute type using tensor types as arithmetic operators are supported on tensors
// and is correct also in the special case of doubles.
// As all our functions are type-commutative, we don't need to take operator precedence into account
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
index 1d7d9b1ecda..cd4ddbcae55 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
@@ -5,7 +5,6 @@ package com.yahoo.searchlib.rankingexpression.rule;
* A node which produces a boolean value when evaluated.
*
* @author bratseth
- * @since 5.1.21
*/
public abstract class BooleanNode extends CompositeNode {
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 7601c0e6180..eb328486045 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -49,7 +50,7 @@ public class ComparisonNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty; // by definition
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index 1ea8d03f0eb..3ddd7223349 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -49,7 +50,7 @@ public final class ConstantNode extends ExpressionNode {
}
@Override
- public TensorType type(TypeContext context) { return value.type(); }
+ public TensorType type(TypeContext<Reference> context) { return value.type(); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index fd9fab99db8..47c2897e4a4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -50,7 +51,7 @@ public final class EmbracedNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index 477f4db4981..6bb163590de 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -48,7 +49,7 @@ public abstract class ExpressionNode implements Serializable {
* @param context the variable type bindings to use for this evaluation
* @throws IllegalArgumentException if there are variables which are not bound in the given map
*/
- public abstract TensorType type(TypeContext context);
+ public abstract TensorType type(TypeContext<Reference> context);
/**
* Returns the value of evaluating this expression over the given context.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index 79515229019..1da2210a39c 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -67,7 +68,7 @@ public final class FunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
if (arguments.expressions().size() == 0)
return TensorType.empty;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
new file mode 100644
index 00000000000..ed1e2838717
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
@@ -0,0 +1,74 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The context of a function invocation.
+ *
+ * @author bratseth
+ */
+public class FunctionReferenceContext {
+
+ /** Expression functions indexed by name */
+ private final ImmutableMap<String, ExpressionFunction> functions;
+
+ /** Mapping from argument names to the expressions they resolve to */
+ // TODO: Make private
+ public final Map<String, String> bindings = new HashMap<>();
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext() {
+ this(Collections.emptyList());
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Collection<ExpressionFunction> functions) {
+ this(toMap(functions), Collections.emptyMap());
+ }
+
+ public FunctionReferenceContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) {
+ this(toMap(functions), bindings);
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions) {
+ this(functions.values());
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) {
+ this.functions = ImmutableMap.copyOf(functions);
+ if (bindings != null)
+ this.bindings.putAll(bindings);
+ }
+
+ private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
+ ImmutableMap.Builder<String,ExpressionFunction> mapBuilder = new ImmutableMap.Builder<>();
+ for (ExpressionFunction function : list)
+ mapBuilder.put(function.getName(), function);
+ return mapBuilder.build();
+ }
+
+ /**
+ * Returns a function or null if it isn't defined in this context
+ */
+ public ExpressionFunction getFunction(String name) { return functions.get(name); }
+
+ protected Map<String, ExpressionFunction> functions() { return functions; }
+
+ /** Returns the resolution of an argument, or null if it isn't defined in this context */
+ public String getBinding(String name) { return bindings.get(name); }
+
+ /** Returns a new context with the bindings replaced by the given bindings */
+ public FunctionReferenceContext withBindings(Map<String, String> bindings) {
+ return new FunctionReferenceContext(this.functions, bindings);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index e42884ecc05..c87eb0ace39 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -48,7 +49,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) { return type; }
+ public TensorType type(TypeContext<Reference> context) { return type; }
/** Evaluate this in a context which must have the arguments bound */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 66b250736e8..ee4edac4941 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -75,7 +76,7 @@ public final class IfNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
TensorType trueType = trueExpression.type(context);
TensorType falseType = falseExpression.type(context);
return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() ->
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index da946228291..61086f8182a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -57,7 +58,7 @@ public class LambdaFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty; // by definition - no nested lambdas
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index f55ed59b65c..f1adf331630 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -14,6 +15,7 @@ import java.util.Deque;
*
* @author Simon Thoresen
*/
+// TODO: This is achieved by ReferenceNode in almost all cases - remove this
public final class NameNode extends ExpressionNode {
private final String name;
@@ -32,7 +34,7 @@ public final class NameNode extends ExpressionNode {
}
@Override
- public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); }
+ public TensorType type(TypeContext<Reference> context) { throw new RuntimeException("Named nodes can not have a type"); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 9cbe5f98c72..fcc03dc4862 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -38,7 +39,7 @@ public class NegativeNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index e7041600635..a539f496ff5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -38,7 +39,7 @@ public class NotNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 05a6773c5cb..78f53b1593d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -13,114 +14,102 @@ import java.util.Deque;
import java.util.List;
/**
- * A node referring either to a value in the context or to another named ranking expression.
+ * A node referring either to a value in the context or to a named ranking expression (function aka macro).
*
* @author simon
* @author bratseth
*/
public final class ReferenceNode extends CompositeNode {
- private final String name, output;
-
- private final Arguments arguments;
+ private final Reference reference;
+ /* Creates a node with a simple identifier reference */
public ReferenceNode(String name) {
this(name, null, null);
}
public ReferenceNode(String name, List<? extends ExpressionNode> arguments, String output) {
- this.name = name;
- this.arguments = arguments != null ? new Arguments(arguments) : new Arguments();
- this.output = output;
+ this.reference = new Reference(name,
+ arguments != null ? new Arguments(arguments) : new Arguments(),
+ output);
+ }
+
+ public ReferenceNode(Reference reference) {
+ this.reference = reference;
}
public String getName() {
- return name;
+ return reference.name();
}
/** Returns the arguments, never null */
- public Arguments getArguments() { return arguments; }
+ public Arguments getArguments() { return reference.arguments(); }
/** Returns a copy of this where the arguments are replaced by the given arguments */
public ReferenceNode setArguments(List<ExpressionNode> arguments) {
- return new ReferenceNode(name, arguments, output);
+ return new ReferenceNode(reference.withArguments(new Arguments(arguments)));
}
/** Returns the specific output this references, or null if none specified */
- public String getOutput() { return output; }
+ public String getOutput() { return reference.output(); }
/** Returns a copy of this node with a modified output */
public ReferenceNode setOutput(String output) {
- return new ReferenceNode(name, arguments.expressions(), output);
+ return new ReferenceNode(reference.withOutput(output));
}
/** Returns an empty list as this has no children */
@Override
- public List<ExpressionNode> children() { return arguments.expressions(); }
+ public List<ExpressionNode> children() { return reference.arguments().expressions(); }
@Override
public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- if (path == null)
- path = new ArrayDeque<>();
- String myName = this.name;
- String myOutput = this.output;
- List<ExpressionNode> myArguments = this.arguments.expressions();
-
- String resolvedArgument = context.getBinding(myName);
- if (resolvedArgument != null && this.arguments.expressions().size() == 0 && myOutput == null) {
- // Replace this whole node with the value of the argument value that it maps to
- myName = resolvedArgument;
- myArguments = null;
- myOutput = null;
- } else if (context.getFunction(myName) != null) {
- // Replace by the referenced expression
- ExpressionFunction function = context.getFunction(myName);
- if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) {
- String myPath = name + this.arguments.expressions();
- if (path.contains(myPath)) {
- throw new IllegalStateException("Cycle in ranking expression function: " + path);
- }
- path.addLast(myPath);
- ExpressionFunction.Instance instance = function.expand(context, myArguments, path);
- path.removeLast();
- context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
- myName = "rankingExpression(" + instance.getName() + ")";
- myArguments = null;
- myOutput = null;
- }
+ if (reference.isIdentifier() && context.getBinding(getName()) != null) {
+ // a bound identifier: replace by the value it is bound to
+ return context.getBinding(getName());
}
- // Always print the same way, the magic is already done.
- StringBuilder ret = new StringBuilder(myName);
- if (myArguments != null && myArguments.size() > 0) {
- ret.append("(");
- for (int i = 0; i < myArguments.size(); ++i) {
- ret.append(myArguments.get(i).toString(context, path, this));
- if (i < myArguments.size() - 1) {
- ret.append(",");
- }
- }
- ret.append(")");
+
+ ExpressionFunction function = context.getFunction(getName());
+ if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) {
+ // a function reference: replace by the referenced function wrapped in rankingExpression
+ if (path == null)
+ path = new ArrayDeque<>();
+ String myPath = getName() + getArguments().expressions();
+ if (path.contains(myPath))
+ throw new IllegalStateException("Cycle in ranking expression function: " + path);
+ path.addLast(myPath);
+ ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
+ path.removeLast();
+ context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
+ return "rankingExpression(" + instance.getName() + ")";
}
- ret.append(myOutput != null ? "." + myOutput : "");
- return ret.toString();
+
+ // not resolved in this context: output as-is
+ return reference.toString(context, path, parent);
}
+ /** Returns the reference of this node */
+ public Reference reference() { return reference; }
+
@Override
- public TensorType type(TypeContext context) {
- // Don't support outputs of different type, for simplicity
- return context.getType(toString());
+ public TensorType type(TypeContext<Reference> context) {
+ TensorType type = context.getType(reference);
+ if (type == null)
+ throw new IllegalArgumentException("Unknown feature '" + toString() + "'");
+ return type;
}
@Override
public Value evaluate(Context context) {
- if (arguments.expressions().isEmpty() && output == null)
- return context.get(name);
- return context.get(name, arguments, output);
+ // TODO: Context should accept a Reference instead.
+ if (reference.isIdentifier())
+ return context.get(reference.name());
+ return context.get(getName(), getArguments(), getOutput());
}
@Override
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
- return new ReferenceNode(name, newChildren, output);
+ return setArguments(newChildren);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index ba765d07094..796c13a8669 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -16,17 +16,11 @@ import java.util.Map;
*
* @author bratseth
*/
-public class SerializationContext {
+public class SerializationContext extends FunctionReferenceContext {
- /** Expression functions indexed by name */
- private final ImmutableMap<String, ExpressionFunction> functions;
-
- /** A cache of already serialized expressions indexed by name */
+ /** Serialized form of functions indexed by name */
private final Map<String, String> serializedFunctions;
- /** Mapping from argument names to the expressions they resolve to */
- public final Map<String, String> bindings = new HashMap<>();
-
/** Create a context for a single serialization task */
public SerializationContext() {
this(Collections.emptyList());
@@ -77,17 +71,10 @@ public class SerializationContext {
*/
public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings,
Map<String, String> serializedFunctions) {
- this.functions = functions;
+ super(functions, bindings);
this.serializedFunctions = serializedFunctions;
- if (bindings != null)
- this.bindings.putAll(bindings);
}
- /**
- * Returns a function or null if it isn't defined in this context
- */
- public ExpressionFunction getFunction(String name) { return functions.get(name); }
-
/** Adds the serialization of a function */
public void addFunctionSerialization(String name, String expressionString) {
serializedFunctions.put(name, expressionString);
@@ -98,17 +85,9 @@ public class SerializationContext {
return serializedFunctions.get(name);
}
- /**
- * Returns the resolution of an argument, or null if it isn't defined in this context
- */
- public String getBinding(String name) { return bindings.get(name); }
-
- /**
- * Returns a new context which shares the functions and serialized function map with this but has different
- * arguments.
- */
- public SerializationContext createBinding(Map<String, String> arguments) {
- return new SerializationContext(this.functions, arguments, this.serializedFunctions);
+ @Override
+ public SerializationContext withBindings(Map<String, String> bindings) {
+ return new SerializationContext(functions().values(), bindings, this.serializedFunctions);
}
public Map<String, String> serializedFunctions() { return serializedFunctions; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index a7b82f4753f..cb31219579a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -60,7 +61,7 @@ public class SetMembershipNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ec6af4bb413..6c9b6bb4a98 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -64,7 +65,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) { return function.type(context); }
+ public TensorType type(TypeContext<Reference> context) { return function.type(context); }
@Override
public Value evaluate(Context context) {
@@ -111,12 +112,13 @@ public class TensorFunctionNode extends CompositeNode {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
- return expression.type(context);
+ @SuppressWarnings("unchecked") // Generics awkwardness
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ return expression.type((TypeContext<Reference>)context);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return expression.evaluate((Context)context).asTensor();
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index e9030cf5852..f2122bb5da9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -378,8 +378,13 @@ public class EvaluationTestCase {
private static class StructuredTestContext extends MapContext {
@Override
+ public Value get(String feature) {
+ throw new RuntimeException("Called simple get for feature " + feature);
+ }
+
+ @Override
public Value get(String name, Arguments arguments, String output) {
- if (!name.equals("average")) {
+ if ( ! name.equals("average")) {
throw new IllegalArgumentException("Unknown operation '" + name + "'");
}
if (arguments.expressions().size() != 2) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
index c882c887c8d..a08d510eec4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
@@ -3,6 +3,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -18,12 +19,17 @@ public class TypeResolutionTestCase {
@Test
public void testTypeResolution() {
- TypeMapContext context = new TypeMapContext();
- context.setType("query(x1)", TensorType.fromSpec("tensor(x[])"));
- context.setType("query(x2)", TensorType.fromSpec("tensor(x[10])"));
- context.setType("query(y1)", TensorType.fromSpec("tensor(y[])"));
- context.setType("query(xy1)", TensorType.fromSpec("tensor(x[10],y[])"));
- context.setType("query(xy2)", TensorType.fromSpec("tensor(x[],y[10])"));
+ MapTypeContext context = new MapTypeContext();
+ context.setType(Reference.simple("query", "x1"),
+ TensorType.fromSpec("tensor(x[])"));
+ context.setType(Reference.simple("query", "x2"),
+ TensorType.fromSpec("tensor(x[10])"));
+ context.setType(Reference.simple("query", "y1"),
+ TensorType.fromSpec("tensor(y[])"));
+ context.setType(Reference.simple("query", "xy1"),
+ TensorType.fromSpec("tensor(x[10],y[])"));
+ context.setType(Reference.simple("query", "xy2"),
+ TensorType.fromSpec("tensor(x[],y[10])"));
assertType("tensor(x[])", "query(x1)", context);
assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context);
@@ -31,7 +37,7 @@ public class TypeResolutionTestCase {
assertIncompatibleType("if (1>0, query(x1), query(y1))", context);
}
- private void assertType(String type, String expression, TypeContext context) {
+ private void assertType(String type, String expression, TypeContext<Reference> context) {
try {
assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context));
}
@@ -40,7 +46,7 @@ public class TypeResolutionTestCase {
}
}
- private void assertIncompatibleType(String expression, TypeContext context) {
+ private void assertIncompatibleType(String expression, TypeContext<Reference> context) {
try {
new RankingExpression(expression).type(context);
fail("Expected type incompatibility exception");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
new file mode 100644
index 00000000000..ebcfde54c70
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
@@ -0,0 +1,49 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertTrue;
+
+public class DimensionRenamerTest {
+
+ @Test
+ public void testMnistRenaming() {
+ DimensionRenamer renamer = new DimensionRenamer();
+
+ renamer.addDimension("first_dimension_of_x");
+ renamer.addDimension("second_dimension_of_x");
+ renamer.addDimension("first_dimension_of_w");
+ renamer.addDimension("second_dimension_of_w");
+ renamer.addDimension("first_dimension_of_b");
+
+ // which dimension to join on matmul
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+
+ // other dimensions in matmul can't be equal
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+
+ // for efficiency, put dimension to join on innermost
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+
+ // bias
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+
+ renamer.solve();
+
+ String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
+ String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
+ String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
+ String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
+ String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
+
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
+ assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
+ assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
+ assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
+
+
+ }
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index 3b25bfe1b1e..f64d697d9b9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -18,11 +18,6 @@ public class DropoutImportTestCase {
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
- // Check (provided) macros
- assertEquals(1, model.get().macros().size());
- assertTrue(model.get().macros().containsKey("training_input"));
- assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString());
-
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
assertTrue(model.get().requiredMacros().containsKey("X"));
@@ -37,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index ad5abd4c03d..60dd3865aa1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase {
// Check constants
assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable_read");
assertNotNull(constant0);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1_read");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
assertEquals(10, constant1.size());
@@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
- model.assertEqualResult("Placeholder", "Variable/read");
- model.assertEqualResult("Placeholder", "Variable_1/read");
model.assertEqualResult("Placeholder", "MatMul");
model.assertEqualResult("Placeholder", "add");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index ae7714b271a..1691756a64d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.SavedModelBundle;
@@ -47,8 +48,11 @@ public class TestableTensorFlowModel {
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
- FloatBuffer.allocate(d0Size * d1Size));
+ FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size);
+ for (int i = 0; i < d1Size; ++i) {
+ fb.put(i, (float)(i * 1.0 / d1Size));
+ }
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb);
runner.feed(inputName, placeholder);
List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
assertEquals(1, results.size());
@@ -66,7 +70,7 @@ public class TestableTensorFlowModel {
Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
for (int d0 = 0; d0 < d0Size; d0++)
for (int d1 = 0; d1 < d1Size; d1++)
- b.cell(0, d0, d1);
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
return b.build();
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
index 867331e99ce..303135888d8 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
@@ -9,13 +9,13 @@ import java.util.Collections;
import static org.junit.Assert.*;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public class ArgumentsTestCase {
@Test
public void requireThatAccessorsWork() {
- Arguments args = new Arguments(null);
+ Arguments args = new Arguments();
assertTrue(args.expressions().isEmpty());
args = new Arguments(Collections.<ExpressionNode>emptyList());
diff --git a/storage/src/tests/common/testnodestateupdater.cpp b/storage/src/tests/common/testnodestateupdater.cpp
index 18f296e5583..c7fd47e37c7 100644
--- a/storage/src/tests/common/testnodestateupdater.cpp
+++ b/storage/src/tests/common/testnodestateupdater.cpp
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "testnodestateupdater.h"
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
namespace storage {
@@ -14,7 +14,7 @@ TestNodeStateUpdater::TestNodeStateUpdater(const lib::NodeType& type)
TestNodeStateUpdater::~TestNodeStateUpdater() = default;
-std::shared_ptr<const ClusterStateBundle>
+std::shared_ptr<const lib::ClusterStateBundle>
TestNodeStateUpdater::getClusterStateBundle() const
{
return _clusterStateBundle;
@@ -23,7 +23,7 @@ TestNodeStateUpdater::getClusterStateBundle() const
void
TestNodeStateUpdater::setClusterState(lib::ClusterState::CSP c)
{
- _clusterStateBundle = std::make_shared<const ClusterStateBundle>(*c);
+ _clusterStateBundle = std::make_shared<const lib::ClusterStateBundle>(*c);
for (uint32_t i = 0; i < _listeners.size(); ++i) {
_listeners[i]->handleNewState();
}
diff --git a/storage/src/tests/common/testnodestateupdater.h b/storage/src/tests/common/testnodestateupdater.h
index daecb45ece4..1e898e84b18 100644
--- a/storage/src/tests/common/testnodestateupdater.h
+++ b/storage/src/tests/common/testnodestateupdater.h
@@ -16,7 +16,7 @@ struct TestNodeStateUpdater : public NodeStateUpdater
{
lib::NodeState::CSP _reported;
lib::NodeState::CSP _current;
- std::shared_ptr<const ClusterStateBundle> _clusterStateBundle;
+ std::shared_ptr<const lib::ClusterStateBundle> _clusterStateBundle;
std::vector<StateListener*> _listeners;
public:
@@ -25,7 +25,7 @@ public:
lib::NodeState::CSP getReportedNodeState() const override { return _reported; }
lib::NodeState::CSP getCurrentNodeState() const override { return _current; }
- std::shared_ptr<const ClusterStateBundle> getClusterStateBundle() const override;
+ std::shared_ptr<const lib::ClusterStateBundle> getClusterStateBundle() const override;
void addStateListener(StateListener& s) override { _listeners.push_back(&s); }
void removeStateListener(StateListener&) override {}
Lock::SP grabStateChangeLock() override { return Lock::SP(new Lock); }
diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
index 2192ae4d634..248fb1e5203 100644
--- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
+++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
@@ -7,7 +7,7 @@
#include <vespa/document/test/make_document_bucket.h>
#include <vespa/storage/storageserver/statemanager.h>
#include <vespa/storage/bucketdb/bucketmanager.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/persistence/persistencethread.h>
#include <vespa/storage/persistence/filestorage/filestormanager.h>
#include <vespa/storage/persistence/filestorage/modifiedbucketchecker.h>
diff --git a/storage/src/tests/storageserver/statemanagertest.cpp b/storage/src/tests/storageserver/statemanagertest.cpp
index 0676d3684ff..7c5303f74fe 100644
--- a/storage/src/tests/storageserver/statemanagertest.cpp
+++ b/storage/src/tests/storageserver/statemanagertest.cpp
@@ -4,7 +4,7 @@
#include <vespa/metrics/metricmanager.h>
#include <vespa/storageapi/message/bucket.h>
#include <vespa/storageapi/message/state.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h>
#include <vespa/storage/storageserver/statemanager.h>
#include <tests/common/teststorageapp.h>
diff --git a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
index 142003735b8..5078d35956a 100644
--- a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
+++ b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
@@ -7,7 +7,7 @@
#include <iomanip>
#include <vespa/storage/common/content_bucket_space_repo.h>
#include <vespa/storage/common/nodestateupdater.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/storageutil/distributorstatecache.h>
#include <vespa/storageframework/generic/status/htmlstatusreporter.h>
#include <vespa/storageframework/generic/status/xmlstatusreporter.h>
diff --git a/storage/src/vespa/storage/common/CMakeLists.txt b/storage/src/vespa/storage/common/CMakeLists.txt
index d1e819523d7..c53aead2ba2 100644
--- a/storage/src/vespa/storage/common/CMakeLists.txt
+++ b/storage/src/vespa/storage/common/CMakeLists.txt
@@ -3,7 +3,6 @@ vespa_add_library(storage_common OBJECT
SOURCES
bucketmessages.cpp
bucketoperationlogger.cpp
- cluster_state_bundle.cpp
content_bucket_space.cpp
content_bucket_space_repo.cpp
distributorcomponent.cpp
diff --git a/storage/src/vespa/storage/common/nodestateupdater.h b/storage/src/vespa/storage/common/nodestateupdater.h
index 7fd3dedbcab..c2887a971f3 100644
--- a/storage/src/vespa/storage/common/nodestateupdater.h
+++ b/storage/src/vespa/storage/common/nodestateupdater.h
@@ -29,7 +29,7 @@
namespace storage {
-class ClusterStateBundle;
+namespace lib { class ClusterStateBundle; }
struct StateListener {
virtual ~StateListener() {}
@@ -43,7 +43,7 @@ struct NodeStateUpdater {
virtual lib::NodeState::CSP getReportedNodeState() const = 0;
virtual lib::NodeState::CSP getCurrentNodeState() const = 0;
- virtual std::shared_ptr<const ClusterStateBundle> getClusterStateBundle() const = 0;
+ virtual std::shared_ptr<const lib::ClusterStateBundle> getClusterStateBundle() const = 0;
virtual void addStateListener(StateListener&) = 0;
virtual void removeStateListener(StateListener&) = 0;
diff --git a/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp b/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp
index cf290c78acf..439bc9e078c 100644
--- a/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp
+++ b/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp
@@ -2,7 +2,7 @@
#include "distributorcomponentregisterimpl.h"
#include <vespa/vdslib/distribution/idealnodecalculatorimpl.h>
#include <vespa/vespalib/util/exceptions.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
namespace storage {
diff --git a/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp b/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp
index 5e561951260..a0f05a70f4e 100644
--- a/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp
+++ b/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp
@@ -4,7 +4,7 @@
#include <vespa/storage/common/nodestateupdater.h>
#include <vespa/storage/common/bucketoperationlogger.h>
#include <vespa/storage/common/content_bucket_space_repo.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storageapi/message/bucket.h>
#include <vespa/vdslib/distribution/distribution.h>
#include <vespa/vespalib/util/backtrace.h>
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
index 2773c19eaa1..311dc52767d 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
+++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
@@ -6,7 +6,7 @@
#include <vespa/storage/common/bucketmessages.h>
#include <vespa/storage/common/bucketoperationlogger.h>
#include <vespa/storage/common/content_bucket_space_repo.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/common/messagebucket.h>
#include <vespa/storage/config/config-stor-server.h>
#include <vespa/storage/persistence/bucketownershipnotifier.h>
diff --git a/storage/src/vespa/storage/storageserver/bouncer.cpp b/storage/src/vespa/storage/storageserver/bouncer.cpp
index af274c9b3e6..72edbfd095e 100644
--- a/storage/src/vespa/storage/storageserver/bouncer.cpp
+++ b/storage/src/vespa/storage/storageserver/bouncer.cpp
@@ -2,7 +2,7 @@
#include "bouncer.h"
#include "bouncer_metrics.h"
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storageapi/message/state.h>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/config/subscription/configuri.h>
diff --git a/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp b/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp
index cd7be21a369..7cf42af841d 100644
--- a/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp
+++ b/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp
@@ -3,7 +3,7 @@
#include "changedbucketownershiphandler.h"
#include <vespa/storageapi/message/state.h>
#include <vespa/storage/bucketdb/storbucketdb.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/common/messagebucket.h>
#include <vespa/storage/common/nodestateupdater.h>
#include <vespa/storage/common/content_bucket_space_repo.h>
diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp
index f53af6dc225..7a0711c9f7c 100644
--- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp
+++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp
@@ -152,7 +152,7 @@ FNetListener::RPC_setSystemState2(FRT_RPCRequest *req)
req->GetParams()->GetValue(0)._string._len);
lib::ClusterState systemState(systemStateStr);
- auto cmd(std::make_shared<api::SetSystemStateCommand>(systemState));
+ auto cmd(std::make_shared<api::SetSystemStateCommand>(lib::ClusterStateBundle(systemState)));
cmd->setPriority(api::StorageMessage::VERYHIGH);
// Create a request object to avoid needing a separate transport type
diff --git a/storage/src/vespa/storage/storageserver/mergethrottler.cpp b/storage/src/vespa/storage/storageserver/mergethrottler.cpp
index 73fa61e9fb7..a15b1b98d63 100644
--- a/storage/src/vespa/storage/storageserver/mergethrottler.cpp
+++ b/storage/src/vespa/storage/storageserver/mergethrottler.cpp
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "mergethrottler.h"
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/common/nodestateupdater.h>
#include <vespa/storage/persistence/messages.h>
#include <vespa/messagebus/message.h>
diff --git a/storage/src/vespa/storage/storageserver/statemanager.cpp b/storage/src/vespa/storage/storageserver/statemanager.cpp
index 1908eab96ec..11ca0bcc9ae 100644
--- a/storage/src/vespa/storage/storageserver/statemanager.cpp
+++ b/storage/src/vespa/storage/storageserver/statemanager.cpp
@@ -9,7 +9,7 @@
#include <vespa/storageapi/messageapi/storagemessage.h>
#include <vespa/storage/storageserver/storagemetricsset.h>
#include <vespa/storage/common/bucketoperationlogger.h>
-#include <vespa/storage/common/cluster_state_bundle.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <sys/types.h>
#include <unistd.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -192,7 +192,7 @@ StateManager::getCurrentNodeState() const
(_systemState->getBaselineClusterState()->getNodeState(thisNode()));
}
-std::shared_ptr<const ClusterStateBundle>
+std::shared_ptr<const lib::ClusterStateBundle>
StateManager::getClusterStateBundle() const
{
vespalib::LockGuard lock(_stateLock);
diff --git a/storage/src/vespa/storage/storageserver/statemanager.h b/storage/src/vespa/storage/storageserver/statemanager.h
index 8d3e4d75a88..9f5c60b42aa 100644
--- a/storage/src/vespa/storage/storageserver/statemanager.h
+++ b/storage/src/vespa/storage/storageserver/statemanager.h
@@ -33,7 +33,7 @@ namespace metrics {
namespace storage {
-class ClusterStateBundle;
+namespace lib { class ClusterStateBundle; }
class StateManager : public NodeStateUpdater,
public StorageLink,
@@ -50,6 +50,7 @@ class StateManager : public NodeStateUpdater,
std::atomic<bool> _notifyingListeners;
std::shared_ptr<lib::NodeState> _nodeState;
std::shared_ptr<lib::NodeState> _nextNodeState;
+ using ClusterStateBundle = lib::ClusterStateBundle;
std::shared_ptr<const ClusterStateBundle> _systemState;
std::shared_ptr<const ClusterStateBundle> _nextSystemState;
std::list<StateListener*> _stateListeners;
diff --git a/storageapi/src/vespa/storageapi/message/state.cpp b/storageapi/src/vespa/storageapi/message/state.cpp
index b128e8f6485..efa9a45764f 100644
--- a/storageapi/src/vespa/storageapi/message/state.cpp
+++ b/storageapi/src/vespa/storageapi/message/state.cpp
@@ -2,6 +2,7 @@
#include "state.h"
#include <vespa/storageapi/messageapi/storagemessage.h>
+#include <vespa/vdslib/state/clusterstate.h>
#include <ostream>
namespace storage {
@@ -61,6 +62,12 @@ GetNodeStateReply::print(std::ostream& out, bool verbose,
}
}
+SetSystemStateCommand::SetSystemStateCommand(const lib::ClusterStateBundle& state)
+ : StorageCommand(MessageType::SETSYSTEMSTATE),
+ _state(state)
+{
+}
+
SetSystemStateCommand::SetSystemStateCommand(const lib::ClusterState& state)
: StorageCommand(MessageType::SETSYSTEMSTATE),
_state(state)
@@ -71,7 +78,7 @@ void
SetSystemStateCommand::print(std::ostream& out, bool verbose,
const std::string& indent) const
{
- out << "SetSystemStateCommand(" << _state << ")";
+ out << "SetSystemStateCommand(" << *_state.getBaselineClusterState() << ")";
if (verbose) {
out << " : ";
StorageCommand::print(out, verbose, indent);
@@ -80,7 +87,7 @@ SetSystemStateCommand::print(std::ostream& out, bool verbose,
SetSystemStateReply::SetSystemStateReply(const SetSystemStateCommand& cmd)
: StorageReply(cmd),
- _state(cmd.getSystemState())
+ _state(cmd.getClusterStateBundle())
{
}
diff --git a/storageapi/src/vespa/storageapi/message/state.h b/storageapi/src/vespa/storageapi/message/state.h
index 746d92fce6b..4e5ad92b259 100644
--- a/storageapi/src/vespa/storageapi/message/state.h
+++ b/storageapi/src/vespa/storageapi/message/state.h
@@ -4,7 +4,8 @@
#include <vespa/storageapi/messageapi/storagecommand.h>
#include <vespa/storageapi/messageapi/storagereply.h>
-#include <vespa/vdslib/state/clusterstate.h>
+#include <vespa/vdslib/state/nodestate.h>
+#include <vespa/vdslib/state/cluster_state_bundle.h>
namespace storage::api {
@@ -60,11 +61,13 @@ public:
* put/get/remove etx)
*/
class SetSystemStateCommand : public StorageCommand {
- lib::ClusterState _state;
+ lib::ClusterStateBundle _state;
public:
- explicit SetSystemStateCommand(const lib::ClusterState&);
- const lib::ClusterState& getSystemState() const { return _state; }
+ explicit SetSystemStateCommand(const lib::ClusterStateBundle &state);
+ explicit SetSystemStateCommand(const lib::ClusterState &state);
+ const lib::ClusterState& getSystemState() const { return *_state.getBaselineClusterState(); }
+ const lib::ClusterStateBundle& getClusterStateBundle() const { return _state; }
void print(std::ostream& out, bool verbose, const std::string& indent) const override;
DECLARE_STORAGECOMMAND(SetSystemStateCommand, onSetSystemState)
@@ -77,13 +80,14 @@ public:
* @brief Reply received after a SetSystemStateCommand.
*/
class SetSystemStateReply : public StorageReply {
- lib::ClusterState _state;
+ lib::ClusterStateBundle _state;
public:
explicit SetSystemStateReply(const SetSystemStateCommand& cmd);
// Not serialized. Available locally
- const lib::ClusterState& getSystemState() const { return _state; }
+ const lib::ClusterState& getSystemState() const { return *_state.getBaselineClusterState(); }
+ const lib::ClusterStateBundle& getClusterStateBundle() const { return _state; }
void print(std::ostream& out, bool verbose, const std::string& indent) const override;
DECLARE_STORAGEREPLY(SetSystemStateReply, onSetSystemStateReply)
diff --git a/vdslib/src/vespa/vdslib/state/CMakeLists.txt b/vdslib/src/vespa/vdslib/state/CMakeLists.txt
index 24402526c85..620e86c2677 100644
--- a/vdslib/src/vespa/vdslib/state/CMakeLists.txt
+++ b/vdslib/src/vespa/vdslib/state/CMakeLists.txt
@@ -7,5 +7,6 @@ vespa_add_library(vdslib_state OBJECT
diskstate.cpp
nodestate.cpp
clusterstate.cpp
+ cluster_state_bundle.cpp
DEPENDS
)
diff --git a/storage/src/vespa/storage/common/cluster_state_bundle.cpp b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.cpp
index 1793c74d378..c55f1aadd06 100644
--- a/storage/src/vespa/storage/common/cluster_state_bundle.cpp
+++ b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.cpp
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "cluster_state_bundle.h"
-#include <vespa/vdslib/state/clusterstate.h>
+#include "clusterstate.h"
-namespace storage {
+namespace storage::lib {
ClusterStateBundle::ClusterStateBundle(const ClusterState &baselineClusterState)
: _baselineClusterState(std::make_shared<const ClusterState>(baselineClusterState))
diff --git a/storage/src/vespa/storage/common/cluster_state_bundle.h b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.h
index af4a12a8b3c..c54df1d1952 100644
--- a/storage/src/vespa/storage/common/cluster_state_bundle.h
+++ b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.h
@@ -4,9 +4,9 @@
#include <vespa/document/bucket/bucketspace.h>
-namespace storage {
+namespace storage::lib {
-namespace lib { class ClusterState; }
+class ClusterState;
/**
* Class representing the baseline cluster state and the derived cluster
@@ -14,10 +14,9 @@ namespace lib { class ClusterState; }
*/
class ClusterStateBundle
{
- using ClusterState = lib::ClusterState;
std::shared_ptr<const ClusterState> _baselineClusterState;
public:
- ClusterStateBundle(const ClusterState &baselineClusterState);
+ explicit ClusterStateBundle(const ClusterState &baselineClusterState);
~ClusterStateBundle();
const std::shared_ptr<const ClusterState> &getBaselineClusterState() const;
const std::shared_ptr<const ClusterState> &getDerivedClusterState(document::BucketSpace bucketSpace) const;
diff --git a/vespa-athenz/CMakeLists.txt b/vespa-athenz/CMakeLists.txt
new file mode 100644
index 00000000000..bb5a1f5b6de
--- /dev/null
+++ b/vespa-athenz/CMakeLists.txt
@@ -0,0 +1,2 @@
+# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+install_fat_java_artifact(vespa-athenz)
diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml
index 5312594472f..31e56f76dd2 100644
--- a/vespa-athenz/pom.xml
+++ b/vespa-athenz/pom.xml
@@ -41,7 +41,12 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
-
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>testutil</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
<!-- compile -->
<dependency>
@@ -110,31 +115,6 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>bundle-plugin</artifactId>
<extensions>true</extensions>
- <configuration>
- <useCommonAssemblyIds>false</useCommonAssemblyIds>
- </configuration>
- </plugin>
- <plugin>
- <groupId>org.codehaus.mojo</groupId>
- <artifactId>build-helper-maven-plugin</artifactId>
- <executions>
- <execution>
- <id>attach-artifacts</id>
- <phase>package</phase>
- <goals>
- <goal>attach-artifact</goal>
- </goals>
- <configuration>
- <artifacts>
- <artifact>
- <file>target/${project.artifactId}-deploy.jar</file>
- <type>jar</type>
- <classifier>deploy</classifier>
- </artifact>
- </artifacts>
- </configuration>
- </execution>
- </executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentials.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java
index 36c1aee49e0..c5dce1c5b1d 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentials.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java
index 4072568d9d2..dd816929bfb 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentialsService.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.container.core.identity.IdentityConfig;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java
index 18f90ce545f..95113e1b0b1 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImpl.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
@@ -8,16 +8,15 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
import com.yahoo.jdisc.Metric;
import com.yahoo.log.LogLevel;
-import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.athenz.api.AthenzIdentityCertificate;
import com.yahoo.vespa.athenz.tls.AthenzSslContextBuilder;
+import com.yahoo.vespa.defaults.Defaults;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
-import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzService.java
index c9e3809ea96..18576ab9bab 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzService.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzService.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtils.java
index 6a766e7c49d..6e74d3bc8b1 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtils.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtils.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.x509.Extension;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/IdentityDocumentService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/IdentityDocumentService.java
index 8a9137a491d..4e88234d5de 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/IdentityDocumentService.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/IdentityDocumentService.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.yahoo.vespa.defaults.Defaults;
import org.apache.http.client.methods.CloseableHttpResponse;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceIdentity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceIdentity.java
index d6e986959cb..b90ce56ca7e 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceIdentity.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceIdentity.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRefreshInformation.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRefreshInformation.java
index d0c22d1d0d2..c627363c0f5 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRefreshInformation.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRefreshInformation.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRegisterInformation.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRegisterInformation.java
index dd9f164fef1..69ddb72b8b8 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRegisterInformation.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRegisterInformation.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/SignedIdentityDocument.java
index 7bbd49c953f..c3b073765ac 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/SignedIdentityDocument.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/SignedIdentityDocument.java
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/package-info.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/package-info.java
index 1b4842327dd..f23ea9406b3 100644
--- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/package-info.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/package-info.java
@@ -3,6 +3,6 @@
* @author mortent
*/
@ExportPackage
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java
index 3a506a39c43..d9dbd73a94e 100644
--- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImplTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java
@@ -1,13 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
-import com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.RunnableWithTag;
-import com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.Scheduler;
import com.yahoo.jdisc.Metric;
import com.yahoo.test.ManualClock;
+import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.RunnableWithTag;
+import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.Scheduler;
import org.junit.Test;
import java.security.cert.X509Certificate;
@@ -19,21 +19,13 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
-import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.INITIAL_BACKOFF_DELAY;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.INITIAL_WAIT_NTOKEN;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.MAX_REGISTER_BACKOFF_DELAY;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.METRICS_UPDATER_TAG;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.REGISTER_INSTANCE_TAG;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.TIMEOUT_INITIAL_WAIT_TAG;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG;
-import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_PERIOD;
+
+import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.METRICS_UPDATER_TAG;
+import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD;
+import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG;
+import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_PERIOD;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtilsTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtilsTest.java
index 0412b9071dd..353c5d3c504 100644
--- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtilsTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtilsTest.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.hosted.athenz.identityprovider;
+package com.yahoo.vespa.athenz.identityprovider;
import org.bouncycastle.pkcs.PKCS10CertificationRequest;
import org.junit.Test;
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
new file mode 100644
index 00000000000..e0e4a0828a9
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
@@ -0,0 +1,33 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.lang;
+
+/**
+ * A mutable long
+ *
+ * @author bratseth
+ */
+public class MutableLong {
+
+ private long value;
+
+ public MutableLong(long value) {
+ this.value = value;
+ }
+
+ public long get() { return value; }
+
+ public void set(long value) { this.value = value; }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long add(long increment) {
+ value += increment;
+ return value;
+ }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long subtract(long increment) {
+ value -= increment;
+ return value;
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 14cd3e70866..0176dac6821 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -77,6 +77,13 @@ public class TensorType {
return Optional.empty();
}
+ /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
+ public Optional<Long> sizeOfDimension(String dimension) {
+ Optional<Dimension> d = dimension(dimension);
+ if ( ! d.isPresent()) return Optional.empty();
+ return d.get().size();
+ }
+
/**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
@@ -207,7 +214,7 @@ public class TensorType {
/** Returns a copy of this with the name set to the given name */
public abstract Dimension withName(String name);
- /** Returns true if this is an indexed bound or unboun type */
+ /** Returns true if this is an indexed bound or unbound type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
/**
@@ -254,6 +261,14 @@ public class TensorType {
return new IndexedBoundDimension(name, size);
}
+ public static Dimension indexed(String name) {
+ return new IndexedUnboundDimension(name);
+ }
+
+ public static Dimension mapped(String name) {
+ return new MappedDimension(name);
+ }
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -367,6 +382,15 @@ public class TensorType {
addDimensionsOf(type);
}
+ /**
+ * Creates a builder from the given dimensions.
+ */
+ public Builder(Iterable<Dimension> dimensions) {
+ for (TensorType.Dimension dimension : dimensions) {
+ dimension(dimension);
+ }
+ }
+
private static final boolean supportsMixedTypes = false;
private void addDimensionsOf(TensorType type) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 3fb94f1251b..8a969180113 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor;
* @author bratseth
*/
@Beta
-public interface EvaluationContext extends TypeContext {
+public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> {
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index 9fe6b7d053f..b9394da31e3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -11,17 +11,20 @@ import java.util.HashMap;
* @author bratseth
*/
@Beta
-public class MapEvaluationContext implements EvaluationContext {
+public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> {
private final java.util.Map<String, Tensor> bindings = new HashMap<>();
- static MapEvaluationContext empty() { return new MapEvaluationContext(); }
-
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
@Override
public TensorType getType(String name) {
- Tensor tensor = bindings.get(name);
+ return getType(new Name(name));
+ }
+
+ @Override
+ public TensorType getType(Name name) {
+ Tensor tensor = bindings.get(name.toString());
if (tensor == null) return null;
return tensor.type();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
index 760a225efdf..ff2e6318b37 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType;
*
* @author bratseth
*/
-public interface TypeContext {
+public interface TypeContext<NAMETYPE extends TypeContext.Name> {
/**
* Returns the type of the tensor with this name.
@@ -16,6 +16,39 @@ public interface TypeContext {
* @return returns the type of the tensor which will be returned by calling getTensor(name)
* or null if getTensor will return null.
*/
+ TensorType getType(NAMETYPE name);
+
+ /**
+ * Returns the type of the tensor with this name by converting from a string name.
+ *
+ * @return returns the type of the tensor which will be returned by calling getTensor(name)
+ * or null if getTensor will return null.
+ */
TensorType getType(String name);
+ /** A name which is just a string. Names are value objects. */
+ class Name {
+
+ private final String name;
+
+ public Name(String name) {
+ this.name = name;
+ }
+
+ @Override
+ public String toString() { return name; }
+
+ @Override
+ public int hashCode() { return name.hashCode(); }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if ( ! (other instanceof Name)) return false;
+ return ((Name)other).name.equals(this.name);
+ }
+
+ }
+
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 34beb465d4c..acb2363cba4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
TensorType givenType = context.getType(name);
if (givenType == null) return null;
verifyType(givenType);
@@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = context.getTensor(name);
if (tensor == null) return null;
verifyType(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 2109b730e1a..bfc0938abcc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction {
/** Finds the type this produces by first converting it to a primitive function */
@Override
- public final TensorType type(TypeContext context) { return toPrimitive().type(context); }
+ public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ return toPrimitive().type(context);
+ }
/** Evaluates this by first converting it to a primitive function */
@Override
- public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
+ public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index c77ed1c0526..13e7c136feb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -3,6 +3,8 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.lang.MutableInteger;
+import com.yahoo.lang.MutableLong;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -60,21 +62,35 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argumentA.type(context), argumentB.type(context));
}
/** Returns the type resulting from concatenating a and b */
private TensorType type(TensorType a, TensorType b) {
+ // TODO: Fail if concat dimension is present but not indexed in a or b
TensorType.Builder builder = new TensorType.Builder(a, b);
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
- b.dimension(dimension).get().size().get()));
+ if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
+ builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
+ b.sizeOfDimension(dimension).orElse(1L)));
+ /*
+ MutableLong concatSize = new MutableLong(0);
+ a.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ b.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
+ */
+ }
return builder.build();
}
+ /** Returns true if this dimension is present and unbound */
+ private boolean unboundIn(TensorType type, String dimensionName) {
+ Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
+ return dimension.isPresent() && ! dimension.get().size().isPresent();
+ }
+
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
a = ensureIndexedDimension(dimension, a);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 50b479da168..a43de297b9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) { return constant.type(); }
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
@Override
- public Tensor evaluate(EvaluationContext context) { return constant; }
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
public String toString(ToStringContext context) { return constant.toString(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e70d1de3db7..edfa8253eb9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) { return type; }
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
for (int i = 0; i < indexes.size(); i++) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 7812c985091..50b0e706a43 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
@@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
- joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 53504868ff2..4a338e5501e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return argument.type(context);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 76a938b9fe2..e045effbe7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -101,11 +101,12 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
private TensorType type(TensorType argumentType) {
+ if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
@@ -114,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index de3d2be265a..af4492ca1e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
@@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
TensorType renamedType = type(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 78ab09c7820..e805e9d87bb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -43,14 +43,14 @@ public abstract class TensorFunction {
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract Tensor evaluate(EvaluationContext context);
+ public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract TensorType type(TypeContext context);
+ public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context);
/** Evaluate with no context */
public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
@@ -58,7 +58,7 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
*
- * @param context a context which must be passed to all nexted functions when requesting the string value
+ * @param context a context which must be passed to all nested functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 7e1f292eb7b..eafa5c4addf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -2,6 +2,9 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -16,51 +19,98 @@ public class ConcatTestCase {
public void testConcatNumbers() {
Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("{2}");
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x"));
+ assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x");
+ assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x");
}
@Test
public void testConcatEqualShapes() {
- Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
- Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
- "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
+ Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x");
+ assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
+ "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }",
+ a, b, "y");
}
@Test
public void testConcatNumberAndVector() {
Tensor a = Tensor.from("{1}");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
+ assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
+ assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
+ a, b, "y");
+ }
+
+ @Test
+ public void testConcatNumberAndVectorUnbound() {
+ Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
- assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
- "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y"));
+ assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
+ assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
+ a, b, "y");
}
@Test
public void testUnequalSizesSameDimension() {
+ Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
+ assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
+ }
+
+ @Test
+ public void testUnequalSizesSameDimensionUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y"));
+ assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
+ assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
}
@Test
public void testUnequalEqualSizesDifferentDimension() {
+ Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
+ assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
+ }
+
+ @Test
+ public void testUnequalEqualSizesDifferentDimensionOneUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
- Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
- assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
- assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z"));
+ Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
+ assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
}
@Test
public void testDimensionsubset() {
Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }");
Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }");
- assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
+ assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x");
+ assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ }
+
+ private void assertConcat(String expected, Tensor a, Tensor b, String dimension) {
+ assertConcat(null, expected, a, b, dimension);
+ }
+
+ private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) {
+ Tensor expectedAsTensor = Tensor.from(expected);
+ TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension)
+ .type(new MapEvaluationContext());
+ Tensor result = a.concat(b, dimension);
+
+ if (expectedType != null)
+ assertEquals(TensorType.fromSpec(expectedType), inferredType);
+ else
+ assertEquals(expectedAsTensor.type(), inferredType);
+
+ assertEquals(expectedAsTensor, result);
}
}