author     Harald Musum <musum@yahooinc.com>  2024-02-02 11:08:43 +0100
committer  Harald Musum <musum@yahooinc.com>  2024-02-02 11:08:43 +0100
commit     b67a76968bb7ea7ab3103c2123a85fbde8f7edc2 (patch)
tree       39aee93644b7bcee17d2ddd5c412915f87048add
parent     d40968833cc2c798692fdcabe817b8341317aae7 (diff)
parent     4604428930d005e6f84619de8f39b7668f4b1787 (diff)
Merge branch 'master' into hmusum/change-default-flag-value-2
331 files changed, 8407 insertions, 4482 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8ce06014b5a..c5a50eb89df 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,7 +11,7 @@ This documents tells you what you need to know to contribute. ## Open development All work on Vespa happens directly on GitHub, -using the [GitHub flow model](https://docs.github.com/en/get-started/quickstart/github-flow). +using the [GitHub flow model](https://docs.github.com/en/get-started/using-github/github-flow). We release the master branch four times a week, and you should expect it to always work. The continuous build of Vespa is at [https://factory.vespa.oath.cloud](https://factory.vespa.oath.cloud). You can follow the fate of each commit there. diff --git a/client/go/go.mod b/client/go/go.mod index 3e721fe2a06..8699f3e9245 100644 --- a/client/go/go.mod +++ b/client/go/go.mod @@ -8,7 +8,7 @@ require ( github.com/fatih/color v1.16.0 // This is the most recent version compatible with Go 1.20. Upgrade when we upgrade our Go version github.com/go-json-experiment/json v0.0.0-20230324203220-04923b7a9528 - github.com/klauspost/compress v1.17.4 + github.com/klauspost/compress v1.17.5 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c diff --git a/client/go/go.sum b/client/go/go.sum index e2b1c85442d..fc5730a071d 100644 --- a/client/go/go.sum +++ b/client/go/go.sum @@ -20,6 +20,8 @@ github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4s github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.5 h1:d4vBd+7CHydUqpFBgUEKkSdtSugf9YFmSkvUYPquI5E= +github.com/klauspost/compress v1.17.5/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/client/go/internal/vespa/crypto.go b/client/go/internal/vespa/crypto.go index 9b4d776d97d..568d7a84d18 100644 --- a/client/go/internal/vespa/crypto.go +++ b/client/go/internal/vespa/crypto.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/pem" + "errors" "fmt" "io" "math/big" @@ -220,3 +221,15 @@ func randomSerialNumber() (*big.Int, error) { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) return rand.Int(rand.Reader, serialNumberLimit) } + +// isTLSAlert returns whether err contains a TLS alert error. 
+func isTLSAlert(err error) bool { + for ; err != nil; err = errors.Unwrap(err) { + // This is ugly, but alert types are currently not exposed: + // https://github.com/golang/go/issues/35234 + if fmt.Sprintf("%T", err) == "tls.alert" { + return true + } + } + return false +} diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index 90d1e1997da..ed3cb146eb1 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -153,7 +153,11 @@ func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Respon if err := s.CurlWriter.print(request, s.TLSOptions, timeout); err != nil { return nil, err } - return s.httpClient.Do(request, timeout) + resp, err := s.httpClient.Do(request, timeout) + if isTLSAlert(err) { + return nil, fmt.Errorf("%w: %s", errAuth, err) + } + return resp, err } // SetClient sets a custom HTTP client that this service should use. diff --git a/client/js/app/package.json b/client/js/app/package.json index e6273edbf75..1cc2432f88f 100644 --- a/client/js/app/package.json +++ b/client/js/app/package.json @@ -32,11 +32,11 @@ "eslint-plugin-react-hooks": "^4", "eslint-plugin-react-perf": "^3", "eslint-plugin-unused-imports": "^3", - "husky": "^8", + "husky": "^9.0.0", "jest": "^29", "lodash": "^4", "prettier": "3", - "pretty-quick": "^3", + "pretty-quick": "^4.0.0", "react-router-dom": "^6", "use-context-selector": "^1", "vite": "^5.0.5" diff --git a/client/js/app/yarn.lock b/client/js/app/yarn.lock index e9dc5bf25fe..b231fbcf61b 100644 --- a/client/js/app/yarn.lock +++ b/client/js/app/yarn.lock @@ -1311,70 +1311,70 @@ resolved "https://registry.yarnpkg.com/@remix-run/router/-/router-1.14.2.tgz#4d58f59908d9197ba3179310077f25c88e49ed17" integrity sha512-ACXpdMM9hmKZww21yEqWwiLws/UPLhNKvimN8RrYSqPSvB3ov7sLvAcfvaxePeLvccTQKGdkDIhLYApZVDFuKg== -"@rollup/rollup-android-arm-eabi@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.4.tgz#b1094962742c1a0349587040bc06185e2a667c9b" - integrity sha512-ub/SN3yWqIv5CWiAZPHVS1DloyZsJbtXmX4HxUTIpS0BHm9pW5iYBo2mIZi+hE3AeiTzHz33blwSnhdUo+9NpA== - -"@rollup/rollup-android-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.4.tgz#96eb86fb549e05b187f2ad06f51d191a23cb385a" - integrity sha512-ehcBrOR5XTl0W0t2WxfTyHCR/3Cq2jfb+I4W+Ch8Y9b5G+vbAecVv0Fx/J1QKktOrgUYsIKxWAKgIpvw56IFNA== - -"@rollup/rollup-darwin-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.4.tgz#2456630c007cc5905cb368acb9ff9fc04b2d37be" - integrity sha512-1fzh1lWExwSTWy8vJPnNbNM02WZDS8AW3McEOb7wW+nPChLKf3WG2aG7fhaUmfX5FKw9zhsF5+MBwArGyNM7NA== - -"@rollup/rollup-darwin-x64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.4.tgz#97742214fc7dfd47a0f74efba6f5ae264e29c70c" - integrity sha512-Gc6cukkF38RcYQ6uPdiXi70JB0f29CwcQ7+r4QpfNpQFVHXRd0DfWFidoGxjSx1DwOETM97JPz1RXL5ISSB0pA== - -"@rollup/rollup-linux-arm-gnueabihf@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.4.tgz#cd933e61d6f689c9cdefde424beafbd92cfe58e2" - integrity sha512-g21RTeFzoTl8GxosHbnQZ0/JkuFIB13C3T7Y0HtKzOXmoHhewLbVTFBQZu+z5m9STH6FZ7L/oPgU4Nm5ErN2fw== - -"@rollup/rollup-linux-arm64-gnu@4.9.4": - version "4.9.4" - resolved 
"https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.4.tgz#33b09bf462f1837afc1e02a1b352af6b510c78a6" - integrity sha512-TVYVWD/SYwWzGGnbfTkrNpdE4HON46orgMNHCivlXmlsSGQOx/OHHYiQcMIOx38/GWgwr/po2LBn7wypkWw/Mg== - -"@rollup/rollup-linux-arm64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.4.tgz#50257fb248832c2308064e3764a16273b6ee4615" - integrity sha512-XcKvuendwizYYhFxpvQ3xVpzje2HHImzg33wL9zvxtj77HvPStbSGI9czrdbfrf8DGMcNNReH9pVZv8qejAQ5A== - -"@rollup/rollup-linux-riscv64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.4.tgz#09589e4e1a073cf56f6249b77eb6c9a8e9b613a8" - integrity sha512-LFHS/8Q+I9YA0yVETyjonMJ3UA+DczeBd/MqNEzsGSTdNvSJa1OJZcSH8GiXLvcizgp9AlHs2walqRcqzjOi3A== - -"@rollup/rollup-linux-x64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.4.tgz#bd312bb5b5f02e54d15488605d15cfd3f90dda7c" - integrity sha512-dIYgo+j1+yfy81i0YVU5KnQrIJZE8ERomx17ReU4GREjGtDW4X+nvkBak2xAUpyqLs4eleDSj3RrV72fQos7zw== - -"@rollup/rollup-linux-x64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.4.tgz#25b3bede85d86438ce28cc642842d10d867d40e9" - integrity sha512-RoaYxjdHQ5TPjaPrLsfKqR3pakMr3JGqZ+jZM0zP2IkDtsGa4CqYaWSfQmZVgFUCgLrTnzX+cnHS3nfl+kB6ZQ== - -"@rollup/rollup-win32-arm64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.4.tgz#95957067eb107f571da1d81939f017d37b4958d3" - integrity sha512-T8Q3XHV+Jjf5e49B4EAaLKV74BbX7/qYBRQ8Wop/+TyyU0k+vSjiLVSHNWdVd1goMjZcbhDmYZUYW5RFqkBNHQ== - -"@rollup/rollup-win32-ia32-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.4.tgz#71b6facad976db527863f698692c6964c0b6e10e" - integrity sha512-z+JQ7JirDUHAsMecVydnBPWLwJjbppU+7LZjffGf+Jvrxq+dVjIE7By163Sc9DKc3ADSU50qPVw0KonBS+a+HQ== - -"@rollup/rollup-win32-x64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.4.tgz#16295ccae354707c9bc6842906bdeaad4f3ba7a5" - integrity sha512-LfdGXCV9rdEify1oxlN9eamvDSjv9md9ZVMAbNHA87xqIfFCxImxan9qZ8+Un54iK2nnqPlbnSi4R54ONtbWBw== +"@rollup/rollup-android-arm-eabi@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.5.tgz#b752b6c88a14ccfcbdf3f48c577ccc3a7f0e66b9" + integrity sha512-idWaG8xeSRCfRq9KpRysDHJ/rEHBEXcHuJ82XY0yYFIWnLMjZv9vF/7DOq8djQ2n3Lk6+3qfSH8AqlmHlmi1MA== + +"@rollup/rollup-android-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.5.tgz#33757c3a448b9ef77b6f6292d8b0ec45c87e9c1a" + integrity sha512-f14d7uhAMtsCGjAYwZGv6TwuS3IFaM4ZnGMUn3aCBgkcHAYErhV1Ad97WzBvS2o0aaDv4mVz+syiN0ElMyfBPg== + +"@rollup/rollup-darwin-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.5.tgz#5234ba62665a3f443143bc8bcea9df2cc58f55fb" + integrity sha512-ndoXeLx455FffL68OIUrVr89Xu1WLzAG4n65R8roDlCoYiQcGGg6MALvs2Ap9zs7AHg8mpHtMpwC8jBBjZrT/w== + +"@rollup/rollup-darwin-x64@4.9.5": + version "4.9.5" + resolved 
"https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.5.tgz#981256c054d3247b83313724938d606798a919d1" + integrity sha512-UmElV1OY2m/1KEEqTlIjieKfVwRg0Zwg4PLgNf0s3glAHXBN99KLpw5A5lrSYCa1Kp63czTpVll2MAqbZYIHoA== + +"@rollup/rollup-linux-arm-gnueabihf@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.5.tgz#120678a5a2b3a283a548dbb4d337f9187a793560" + integrity sha512-Q0LcU61v92tQB6ae+udZvOyZ0wfpGojtAKrrpAaIqmJ7+psq4cMIhT/9lfV6UQIpeItnq/2QDROhNLo00lOD1g== + +"@rollup/rollup-linux-arm64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.5.tgz#c99d857e2372ece544b6f60b85058ad259f64114" + integrity sha512-dkRscpM+RrR2Ee3eOQmRWFjmV/payHEOrjyq1VZegRUa5OrZJ2MAxBNs05bZuY0YCtpqETDy1Ix4i/hRqX98cA== + +"@rollup/rollup-linux-arm64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.5.tgz#3064060f568a5718c2a06858cd6e6d24f2ff8632" + integrity sha512-QaKFVOzzST2xzY4MAmiDmURagWLFh+zZtttuEnuNn19AiZ0T3fhPyjPPGwLNdiDT82ZE91hnfJsUiDwF9DClIQ== + +"@rollup/rollup-linux-riscv64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.5.tgz#987d30b5d2b992fff07d055015991a57ff55fbad" + integrity sha512-HeGqmRJuyVg6/X6MpE2ur7GbymBPS8Np0S/vQFHDmocfORT+Zt76qu+69NUoxXzGqVP1pzaY6QIi0FJWLC3OPA== + +"@rollup/rollup-linux-x64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.5.tgz#85946ee4d068bd12197aeeec2c6f679c94978a49" + integrity sha512-Dq1bqBdLaZ1Gb/l2e5/+o3B18+8TI9ANlA1SkejZqDgdU/jK/ThYaMPMJpVMMXy2uRHvGKbkz9vheVGdq3cJfA== + +"@rollup/rollup-linux-x64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.5.tgz#fe0b20f9749a60eb1df43d20effa96c756ddcbd4" + integrity sha512-ezyFUOwldYpj7AbkwyW9AJ203peub81CaAIVvckdkyH8EvhEIoKzaMFJj0G4qYJ5sw3BpqhFrsCc30t54HV8vg== + +"@rollup/rollup-win32-arm64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.5.tgz#422661ef0e16699a234465d15b2c1089ef963b2a" + integrity sha512-aHSsMnUw+0UETB0Hlv7B/ZHOGY5bQdwMKJSzGfDfvyhnpmVxLMGnQPGNE9wgqkLUs3+gbG1Qx02S2LLfJ5GaRQ== + +"@rollup/rollup-win32-ia32-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.5.tgz#7b73a145891c202fbcc08759248983667a035d85" + integrity sha512-AiqiLkb9KSf7Lj/o1U3SEP9Zn+5NuVKgFdRIZkvd4N0+bYrTOovVd0+LmYCPQGbocT4kvFyK+LXCDiXPBF3fyA== + +"@rollup/rollup-win32-x64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.5.tgz#10491ccf4f63c814d4149e0316541476ea603602" + integrity sha512-1q+mykKE3Vot1kaFJIDoUFv5TuW+QQVaf2FmTT9krg86pQrGStOSJJ0Zil7CFagyxDuouTepzt5Y5TVzyajOdQ== "@sinclair/typebox@^0.27.8": version "0.27.8" @@ -2151,7 +2151,7 @@ cross-spawn@^6.0.0: shebang-command "^1.2.0" which "^1.2.9" -cross-spawn@^7.0.0, cross-spawn@^7.0.2, cross-spawn@^7.0.3: +cross-spawn@^7.0.2, cross-spawn@^7.0.3: version "7.0.3" resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" integrity 
sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== @@ -2664,22 +2664,7 @@ execa@^1.0.0: signal-exit "^3.0.0" strip-eof "^1.0.0" -execa@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/execa/-/execa-4.1.0.tgz#4e5491ad1572f2f17a77d388c6c857135b22847a" - integrity sha512-j5W0//W7f8UxAn8hXVnwG8tLwdiUy4FJLcSupCg6maBYZDpyBvTApK7KyuI4bKj8KOh1r2YH+6ucuYtJv1bTZA== - dependencies: - cross-spawn "^7.0.0" - get-stream "^5.0.0" - human-signals "^1.1.1" - is-stream "^2.0.0" - merge-stream "^2.0.0" - npm-run-path "^4.0.0" - onetime "^5.1.0" - signal-exit "^3.0.2" - strip-final-newline "^2.0.0" - -execa@^5.0.0: +execa@^5.0.0, execa@^5.1.1: version "5.1.1" resolved "https://registry.yarnpkg.com/execa/-/execa-5.1.1.tgz#f80ad9cbf4298f7bd1d4c9555c21e93741c411dd" integrity sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg== @@ -2931,13 +2916,6 @@ get-stream@^4.0.0: dependencies: pump "^3.0.0" -get-stream@^5.0.0: - version "5.2.0" - resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-5.2.0.tgz#4966a1795ee5ace65e706c4b7beb71257d6e22d3" - integrity sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA== - dependencies: - pump "^3.0.0" - get-stream@^6.0.0: version "6.0.1" resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-6.0.1.tgz#a262d8eef67aced57c2852ad6167526a43cbf7b7" @@ -3105,20 +3083,15 @@ html-escaper@^2.0.0: resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== -human-signals@^1.1.1: - version "1.1.1" - resolved "https://registry.yarnpkg.com/human-signals/-/human-signals-1.1.1.tgz#c5b1cd14f50aeae09ab6c59fe63ba3395fe4dfa3" - integrity sha512-SEQu7vl8KjNL2eoGBLF3+wAjpsNfA9XMlXAYj/3EdaNfAlxKthD1xjEQfGOUhllCGGJVNY34bRr6lPINhNjyZw== - human-signals@^2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/human-signals/-/human-signals-2.1.0.tgz#dc91fcba42e4d06e4abaed33b3e7a3c02f514ea0" integrity sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw== -husky@^8: - version "8.0.3" - resolved "https://registry.yarnpkg.com/husky/-/husky-8.0.3.tgz#4936d7212e46d1dea28fef29bb3a108872cd9184" - integrity sha512-+dQSyqPh4x1hlO1swXBiNb2HzTDN1I2IGLQx1GrBuiqFJfoMrnZWwVmatvSiO+Iz8fBUnf+lekwNo4c2LlXItg== +husky@^9.0.0: + version "9.0.9" + resolved "https://registry.yarnpkg.com/husky/-/husky-9.0.9.tgz#3a48d0666bf871de14871865f929a5dceabc07f8" + integrity sha512-eW92PRr1XPKDWd7/iM2JvAl9gEKK3TF69yvbllQtKSYBw+Wtoi+P38NqH1Z7++sSd80FBkFagBFJkoQvMhCnGw== ignore@^5.2.0, ignore@^5.3.0: version "5.3.0" @@ -4301,7 +4274,7 @@ npm-run-path@^2.0.0: dependencies: path-key "^2.0.0" -npm-run-path@^4.0.0, npm-run-path@^4.0.1: +npm-run-path@^4.0.1: version "4.0.1" resolved "https://registry.yarnpkg.com/npm-run-path/-/npm-run-path-4.0.1.tgz#b7ecd1e5ed53da8e37a55e1c2269e0b97ed748ea" integrity sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw== @@ -4417,7 +4390,7 @@ once@^1.3.0, once@^1.3.1, once@^1.4.0: dependencies: wrappy "1" -onetime@^5.1.0, onetime@^5.1.2: +onetime@^5.1.2: version "5.1.2" resolved "https://registry.yarnpkg.com/onetime/-/onetime-5.1.2.tgz#d0e96ebb56b07476df1dd9c4806e5237985ca45e" integrity sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg== @@ -4593,13 +4566,13 
@@ pretty-format@^29.7.0: ansi-styles "^5.0.0" react-is "^18.0.0" -pretty-quick@^3: - version "3.3.1" - resolved "https://registry.yarnpkg.com/pretty-quick/-/pretty-quick-3.3.1.tgz#cfde97fec77a8d201a0e0c9c71d9990e12587ee2" - integrity sha512-3b36UXfYQ+IXXqex6mCca89jC8u0mYLqFAN5eTQKoXO6oCQYcIVYZEB/5AlBHI7JPYygReM2Vv6Vom/Gln7fBg== +pretty-quick@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/pretty-quick/-/pretty-quick-4.0.0.tgz#ea5cce85a5804bfbec7327b0e064509155d03f39" + integrity sha512-M+2MmeufXb/M7Xw3Afh1gxcYpj+sK0AxEfnfF958ktFeAyi5MsKY5brymVURQLgPLV1QaF5P4pb2oFJ54H3yzQ== dependencies: - execa "^4.1.0" - find-up "^4.1.0" + execa "^5.1.1" + find-up "^5.0.0" ignore "^5.3.0" mri "^1.2.0" picocolors "^1.0.0" @@ -4670,17 +4643,17 @@ react-refresh@^0.14.0: integrity sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ== react-router-dom@^6: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.2.tgz#5fba851731a194fa32c31990c4829c5e247f650a" - integrity sha512-tE13UukgUOh2/sqYr6jPzZTzmzc70aGRP4pAjG2if0IP3aUT+sBtAKUJh0qMh0zylJHGLmzS+XWVaON4UklHeg== + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.3.tgz#ef3a7956a3699c7b82c21fcb3dbc63c313ed8c5d" + integrity sha512-kNzubk7n4YHSrErzjLK72j0B5i969GsuCGazRl3G6j1zqZBLjuSlYBdVdkDOgzGdPIffUOc9nmgiadTEVoq91g== dependencies: "@remix-run/router" "1.14.2" - react-router "6.21.2" + react-router "6.21.3" -react-router@6.21.2: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.2.tgz#8820906c609ae7e4e8f926cc8eb5ce161428b956" - integrity sha512-jJcgiwDsnaHIeC+IN7atO0XiSRCrOsQAHHbChtJxmgqG2IaYQXSnhqGb5vk2CU/wBQA12Zt+TkbuJjIn65gzbA== +react-router@6.21.3: + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.3.tgz#8086cea922c2bfebbb49c6594967418f1f167d70" + integrity sha512-a0H638ZXULv1OdkmiK6s6itNhoy33ywxmUFT/xtSoVyf9VnC7n7+VT4LjVzdIHSaF5TIh9ylUgxMXksHTgGrKg== dependencies: "@remix-run/router" "1.14.2" @@ -4845,25 +4818,25 @@ rimraf@^3.0.2: glob "^7.1.3" rollup@^4.2.0: - version "4.9.4" - resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.4.tgz#37bc0c09ae6b4538a9c974f4d045bb64b2e7c27c" - integrity sha512-2ztU7pY/lrQyXSCnnoU4ICjT/tCG9cdH3/G25ERqE3Lst6vl2BCM5hL2Nw+sslAvAf+ccKsAq1SkKQALyqhR7g== + version "4.9.5" + resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.5.tgz#62999462c90f4c8b5d7c38fc7161e63b29101b05" + integrity sha512-E4vQW0H/mbNMw2yLSqJyjtkHY9dslf/p0zuT1xehNRqUTBOFMqEjguDvqhXr7N7r/4ttb2jr4T41d3dncmIgbQ== dependencies: "@types/estree" "1.0.5" optionalDependencies: - "@rollup/rollup-android-arm-eabi" "4.9.4" - "@rollup/rollup-android-arm64" "4.9.4" - "@rollup/rollup-darwin-arm64" "4.9.4" - "@rollup/rollup-darwin-x64" "4.9.4" - "@rollup/rollup-linux-arm-gnueabihf" "4.9.4" - "@rollup/rollup-linux-arm64-gnu" "4.9.4" - "@rollup/rollup-linux-arm64-musl" "4.9.4" - "@rollup/rollup-linux-riscv64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-musl" "4.9.4" - "@rollup/rollup-win32-arm64-msvc" "4.9.4" - "@rollup/rollup-win32-ia32-msvc" "4.9.4" - "@rollup/rollup-win32-x64-msvc" "4.9.4" + "@rollup/rollup-android-arm-eabi" "4.9.5" + "@rollup/rollup-android-arm64" "4.9.5" + "@rollup/rollup-darwin-arm64" "4.9.5" + "@rollup/rollup-darwin-x64" "4.9.5" + "@rollup/rollup-linux-arm-gnueabihf" "4.9.5" + "@rollup/rollup-linux-arm64-gnu" "4.9.5" + "@rollup/rollup-linux-arm64-musl" "4.9.5" + 
"@rollup/rollup-linux-riscv64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-musl" "4.9.5" + "@rollup/rollup-win32-arm64-msvc" "4.9.5" + "@rollup/rollup-win32-ia32-msvc" "4.9.5" + "@rollup/rollup-win32-x64-msvc" "4.9.5" fsevents "~2.3.2" rsvp@^4.8.4: @@ -5474,9 +5447,9 @@ v8-to-istanbul@^9.0.1: convert-source-map "^1.6.0" vite@^5.0.5: - version "5.0.11" - resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.11.tgz#31562e41e004cb68e1d51f5d2c641ab313b289e4" - integrity sha512-XBMnDjZcNAw/G1gEiskiM1v6yzM4GE5aMGvhWTlHAYYhxb7S3/V1s3m2LDHa8Vh6yIWYYB0iJwsEaS523c4oYA== + version "5.0.12" + resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.12.tgz#8a2ffd4da36c132aec4adafe05d7adde38333c47" + integrity sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w== dependencies: esbuild "^0.19.3" postcss "^8.4.32" diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java index 93cf9beb70f..814cb48c49f 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java @@ -11,6 +11,7 @@ import com.yahoo.vdslib.state.Node; import com.yahoo.vdslib.state.NodeState; import com.yahoo.vdslib.state.State; import com.yahoo.vespa.clustercontroller.core.hostinfo.HostInfo; +import com.yahoo.vespa.clustercontroller.core.hostinfo.Metrics; import com.yahoo.vespa.clustercontroller.core.hostinfo.StorageNode; import com.yahoo.vespa.clustercontroller.utils.staterestapi.requests.SetUnitStateRequest; @@ -25,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -49,7 +51,9 @@ public class NodeStateChangeChecker { private static final Logger log = Logger.getLogger(NodeStateChangeChecker.class.getName()); private static final String BUCKETS_METRIC_NAME = StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL.baseName(); - private static final Map<String, String> BUCKETS_METRIC_DIMENSIONS = Map.of("bucketSpace", "default"); + private static final String ENTRIES_METRIC_NAME = StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_ENTRIES.baseName(); + private static final String DOCS_METRIC_NAME = StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_DOCS.baseName(); + private static final Map<String, String> DEFAULT_SPACE_METRIC_DIMENSIONS = Map.of("bucketSpace", "default"); private final int requiredRedundancy; private final HierarchicalGroupVisiting groupVisiting; @@ -107,6 +111,50 @@ public class NodeStateChangeChecker { && Objects.equals(newWantedState.getDescription(), oldWantedState.getDescription()); } + private record NodeDataMetrics(Optional<Metrics.Value> buckets, + Optional<Metrics.Value> entries, + Optional<Metrics.Value> docs) {} + + private static NodeDataMetrics dataMetricsFromHostInfo(HostInfo hostInfo) { + return new NodeDataMetrics( + hostInfo.getMetrics().getValueAt(BUCKETS_METRIC_NAME, DEFAULT_SPACE_METRIC_DIMENSIONS), + hostInfo.getMetrics().getValueAt(ENTRIES_METRIC_NAME, DEFAULT_SPACE_METRIC_DIMENSIONS), + hostInfo.getMetrics().getValueAt(DOCS_METRIC_NAME, DEFAULT_SPACE_METRIC_DIMENSIONS)); + } + + private static Optional<Result> 
checkZeroEntriesStoredOnContentNode(NodeDataMetrics metrics, int nodeIndex) { + if (metrics.entries.isEmpty() || metrics.entries.get().getLast() == null) { + // To allow for rolling upgrades in clusters with content node versions that do not report + // an entry count, defer to legacy bucket count check if the entry metric can't be found. + return Optional.empty(); + } + if (metrics.docs.isEmpty() || metrics.docs.get().getLast() == null) { + log.log(Level.WARNING, "Host info inconsistency: storage node %d reports entry count but not document count".formatted(nodeIndex)); + return Optional.of(disallow("The storage node host info reports stored entry count, but not document count")); + } + long lastEntries = metrics.entries.get().getLast(); + long lastDocs = metrics.docs.get().getLast(); + if (lastEntries != 0) { + long buckets = metrics.buckets.map(Metrics.Value::getLast).orElse(-1L); + long tombstones = lastEntries - lastDocs; // docs are a subset of entries, so |docs| <= |entries| + return Optional.of(disallow("The storage node stores %d documents and %d tombstones across %d buckets".formatted(lastDocs, tombstones, buckets))); + } + // At this point we believe we have zero entries. Cross-check with visible doc count; it should + // always be present when an entry count of zero is present and transitively always be zero. + if (lastDocs != 0) { + log.log(Level.WARNING, "Host info inconsistency: storage node %d reports 0 entries, but %d documents".formatted(nodeIndex, lastDocs)); + return Optional.of(disallow("The storage node reports 0 entries, but %d documents".formatted(lastDocs))); + } + return Optional.of(allow()); + } + + private static Result checkLegacyZeroBucketsStoredOnContentNode(long lastBuckets) { + if (lastBuckets != 0) { + return disallow("The storage node manages %d buckets".formatted(lastBuckets)); + } + return allow(); + } + private Result canSetStateDownPermanently(NodeInfo nodeInfo, ClusterState clusterState, String newDescription) { var result = checkIfStateSetWithDifferentDescription(nodeInfo, newDescription); if (result.notAllowed()) @@ -129,15 +177,20 @@ public class NodeStateChangeChecker { " got info for storage node " + nodeIndex + " at a different version " + hostInfoNodeVersion); - var bucketsMetric = hostInfo.getMetrics().getValueAt(BUCKETS_METRIC_NAME, BUCKETS_METRIC_DIMENSIONS); - if (bucketsMetric.isEmpty() || bucketsMetric.get().getLast() == null) + var metrics = dataMetricsFromHostInfo(hostInfo); + // Bucket count metric should always be present + if (metrics.buckets.isEmpty() || metrics.buckets.get().getLast() == null) { return disallow("Missing last value of the " + BUCKETS_METRIC_NAME + " metric for storage node " + nodeIndex); + } - long lastBuckets = bucketsMetric.get().getLast(); - if (lastBuckets > 0) - return disallow("The storage node manages " + lastBuckets + " buckets"); - - return allow(); + // TODO should also ideally check merge pending from the distributors' perspectives. + // - This goes in particular for the global space, as we only check for zero entries in + // the _default_ space (as global entries are retained even for retired nodes). + // - Due to global merges being prioritized above everything else, it is highly unlikely + // that there will be any pending global merges, but the possibility still exists. 
+ // - Would need wiring of aggregated content cluster stats + var entriesCheckResult = checkZeroEntriesStoredOnContentNode(metrics, nodeIndex); + return entriesCheckResult.orElseGet(() -> checkLegacyZeroBucketsStoredOnContentNode(metrics.buckets.get().getLast())); } private Result canSetStateUp(NodeInfo nodeInfo, NodeState oldWantedState) { diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java index 199a23f49ba..c1acc19ae9e 100644 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java +++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java @@ -1,7 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.clustercontroller.core; -import com.yahoo.log.LogSetup; import com.yahoo.vdslib.distribution.ConfiguredNode; import com.yahoo.vdslib.distribution.Distribution; import com.yahoo.vdslib.state.ClusterState; @@ -15,6 +14,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.yahoo.vdslib.state.NodeType.DISTRIBUTOR; import static com.yahoo.vdslib.state.NodeType.STORAGE; @@ -709,6 +710,10 @@ public class NodeStateChangeCheckerTest { assertFalse(result.isAlreadySet()); } + private record HostInfoMetrics(int bucketCount, Integer entryCountOrNull, Integer docCountOrNull) { + static HostInfoMetrics zero() { return new HostInfoMetrics(0, 0, 0); } + } + @ParameterizedTest @ValueSource(ints = {-1, 1}) void testDownDisallowedByNonRetiredState(int maxNumberOfGroupsAllowedToBeDown) { @@ -716,25 +721,43 @@ public class NodeStateChangeCheckerTest { defaultAllUpClusterState(), UP, currentClusterStateVersion, - 0, + HostInfoMetrics.zero(), maxNumberOfGroupsAllowedToBeDown); assertFalse(result.allowed()); assertFalse(result.isAlreadySet()); assertEquals("Only retired nodes are allowed to be set to DOWN in safe mode - is Up", result.reason()); } + private record MetricsAndMessage(HostInfoMetrics hostInfoMetrics, String expectedMessage) {} + @ParameterizedTest @ValueSource(ints = {-1, 1}) - void testDownDisallowedByBuckets(int maxNumberOfGroupsAllowedToBeDown) { - Result result = evaluateDownTransition( - retiredClusterStateSuffix(), - UP, - currentClusterStateVersion, - 1, - maxNumberOfGroupsAllowedToBeDown); - assertFalse(result.allowed()); - assertFalse(result.isAlreadySet()); - assertEquals("The storage node manages 1 buckets", result.reason()); + void down_disallowed_by_host_info_metrics_implying_node_still_stores_data(int maxNumberOfGroupsAllowedToBeDown) { + var disallowCases = List.of( + // Non-zero bucket count, and no entry/doc count metrics + new MetricsAndMessage(new HostInfoMetrics(1, null, null), "The storage node manages 1 buckets"), + // Non-zero bucket count and non-zero entries (note that we prefer reporting the entry count over the bucket count) + new MetricsAndMessage(new HostInfoMetrics(1, 2, 1), "The storage node stores 1 documents and 1 tombstones across 1 buckets"), + + // These are cases that should not normally happen, but we test them nevertheless: + // Bucket count should never be zero if the entry count is > 0 + new 
MetricsAndMessage(new HostInfoMetrics(0, 2, 1), "The storage node stores 1 documents and 1 tombstones across 0 buckets"), + // Entry count should never be zero if the document count is > 0 + new MetricsAndMessage(new HostInfoMetrics(0, 0, 2), "The storage node reports 0 entries, but 2 documents"), + // Document count should always be present alongside entry count + new MetricsAndMessage(new HostInfoMetrics(0, 0, null), "The storage node host info reports stored entry count, but not document count") + ); + for (var dc : disallowCases) { + Result result = evaluateDownTransition( + retiredClusterStateSuffix(), + UP, + currentClusterStateVersion, + dc.hostInfoMetrics, + maxNumberOfGroupsAllowedToBeDown); + assertFalse(result.allowed()); + assertFalse(result.isAlreadySet()); + assertEquals(dc.expectedMessage, result.reason()); + } } @ParameterizedTest @@ -744,7 +767,7 @@ public class NodeStateChangeCheckerTest { retiredClusterStateSuffix(), INITIALIZING, currentClusterStateVersion, - 0, + HostInfoMetrics.zero(), maxNumberOfGroupsAllowedToBeDown); assertFalse(result.allowed()); assertFalse(result.isAlreadySet()); @@ -758,7 +781,7 @@ public class NodeStateChangeCheckerTest { retiredClusterStateSuffix(), UP, currentClusterStateVersion - 1, - 0, + HostInfoMetrics.zero(), maxNumberOfGroupsAllowedToBeDown); assertFalse(result.allowed()); assertFalse(result.isAlreadySet()); @@ -766,14 +789,42 @@ public class NodeStateChangeCheckerTest { result.reason()); } + // Legacy fallback when the content node does not report stored entry count (docs + tombstones) + @ParameterizedTest + @ValueSource(ints = {-1, 1}) + void allowed_to_set_down_when_no_buckets_without_entry_metrics(int maxNumberOfGroupsAllowedToBeDown) { + Result result = evaluateDownTransition( + retiredClusterStateSuffix(), + UP, + currentClusterStateVersion, + new HostInfoMetrics(0, null, null), + maxNumberOfGroupsAllowedToBeDown); + assertTrue(result.allowed()); + assertFalse(result.isAlreadySet()); + } + + @ParameterizedTest + @ValueSource(ints = {-1, 1}) + void allowed_to_set_down_when_no_stored_entries_or_buckets(int maxNumberOfGroupsAllowedToBeDown) { + Result result = evaluateDownTransition( + retiredClusterStateSuffix(), + UP, + currentClusterStateVersion, + HostInfoMetrics.zero(), + maxNumberOfGroupsAllowedToBeDown); + assertTrue(result.allowed()); + assertFalse(result.isAlreadySet()); + } + @ParameterizedTest @ValueSource(ints = {-1, 1}) - void testAllowedToSetDown(int maxNumberOfGroupsAllowedToBeDown) { + void allowed_to_set_down_when_no_stored_entries_but_empty_buckets_are_present(int maxNumberOfGroupsAllowedToBeDown) { Result result = evaluateDownTransition( retiredClusterStateSuffix(), UP, currentClusterStateVersion, - 0, + // The node has (orphaned) buckets, but nothing is stored in them + new HostInfoMetrics(100, 0, 0), maxNumberOfGroupsAllowedToBeDown); assertTrue(result.allowed()); assertFalse(result.isAlreadySet()); @@ -782,14 +833,14 @@ public class NodeStateChangeCheckerTest { private Result evaluateDownTransition(ClusterState clusterState, State reportedState, int hostInfoClusterStateVersion, - int lastAlldisksBuckets, + HostInfoMetrics hostInfoMetrics, int maxNumberOfGroupsAllowedToBeDown) { ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown); NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster); StorageNodeInfo nodeInfo = cluster.clusterInfo().getStorageNodeInfo(nodeStorage.getIndex()); nodeInfo.setReportedState(new NodeState(STORAGE, reportedState), 0); - 
nodeInfo.setHostInfo(createHostInfoWithMetrics(hostInfoClusterStateVersion, lastAlldisksBuckets)); + nodeInfo.setHostInfo(createHostInfoWithMetrics(hostInfoClusterStateVersion, hostInfoMetrics)); return nodeStateChangeChecker.evaluateTransition( nodeStorage, clusterState, SAFE, @@ -802,90 +853,68 @@ public class NodeStateChangeCheckerTest { nodeStorage.getIndex())); } - private static HostInfo createHostInfoWithMetrics(int clusterStateVersion, int lastAlldisksBuckets) { - return HostInfo.createHostInfo(String.format("{\n" + - " \"metrics\":\n" + - " {\n" + - " \"snapshot\":\n" + - " {\n" + - " \"from\":1494940706,\n" + - " \"to\":1494940766\n" + - " },\n" + - " \"values\":\n" + - " [\n" + - " {\n" + - " \"name\":\"vds.datastored.alldisks.buckets\",\n" + - " \"description\":\"buckets managed\",\n" + - " \"values\":\n" + - " {\n" + - " \"average\":262144.0,\n" + - " \"count\":1,\n" + - " \"rate\":0.016666,\n" + - " \"min\":262144,\n" + - " \"max\":262144,\n" + - " \"last\":%d\n" + - " },\n" + - " \"dimensions\":\n" + - " {\n" + - " }\n" + - " },\n" + - " {\n" + - " \"name\":\"vds.datastored.alldisks.docs\",\n" + - " \"description\":\"documents stored\",\n" + - " \"values\":\n" + - " {\n" + - " \"average\":154689587.0,\n" + - " \"count\":1,\n" + - " \"rate\":0.016666,\n" + - " \"min\":154689587,\n" + - " \"max\":154689587,\n" + - " \"last\":154689587\n" + - " },\n" + - " \"dimensions\":\n" + - " {\n" + - " }\n" + - " },\n" + - " {\n" + - " \"name\":\"vds.datastored.bucket_space.buckets_total\",\n" + - " \"description\":\"Total number buckets present in the bucket space (ready + not ready)\",\n" + - " \"values\":\n" + - " {\n" + - " \"average\":0.0,\n" + - " \"sum\":0.0,\n" + - " \"count\":1,\n" + - " \"rate\":0.016666,\n" + - " \"min\":0,\n" + - " \"max\":0,\n" + - " \"last\":0\n" + - " },\n" + - " \"dimensions\":\n" + - " {\n" + - " \"bucketSpace\":\"global\"\n" + - " }\n" + - " },\n" + - " {\n" + - " \"name\":\"vds.datastored.bucket_space.buckets_total\",\n" + - " \"description\":\"Total number buckets present in the bucket space (ready + not ready)\",\n" + - " \"values\":\n" + - " {\n" + - " \"average\":129.0,\n" + - " \"sum\":129.0,\n" + - " \"count\":1,\n" + - " \"rate\":0.016666,\n" + - " \"min\":129,\n" + - " \"max\":129,\n" + - " \"last\":%d\n" + - " },\n" + - " \"dimensions\":\n" + - " {\n" + - " \"bucketSpace\":\"default\"\n" + - " }\n" + - " }\n" + - " ]\n" + - " },\n" + - " \"cluster-state-version\":%d\n" + - "}", - lastAlldisksBuckets, lastAlldisksBuckets, clusterStateVersion)); + private static String bucketSpacesMetricJsonIfNonNull(String metric, Integer lastValueOrNull) { + if (lastValueOrNull != null) { + // We fake the value for the global space; its actual value does not matter, as global + // document [entry] count is not taken into consideration for node removals (they are never + // moved away during retirement). We just want a unique value to test that we don't + // accidentally use the wrong dimension. + return Stream.of("default", "global") + .map(bucketSpace -> + """ + { + "name":"vds.datastored.bucket_space.%s", + "values":{"last":%d}, + "dimensions":{"bucketSpace":"%s"} + },""".formatted(metric, (lastValueOrNull + (bucketSpace.equals("default") ? 
0 : 123)), bucketSpace)) + .collect(Collectors.joining("\n")); + } + return ""; + } + + private static HostInfo createHostInfoWithMetrics(int clusterStateVersion, HostInfoMetrics hostInfoMetrics) { + return HostInfo.createHostInfo(String.format(""" + { + "metrics": + { + "snapshot": + { + "from":1494940706, + "to":1494940766 + }, + "values": + [ + %s + %s + %s + { + "name":"vds.datastored.alldisks.buckets", + "values": + { + "last":%d + }, + "dimensions": + { + } + }, + { + "name":"vds.datastored.alldisks.docs", + "values": + { + "last":154689587 + }, + "dimensions": + { + } + } + ] + }, + "cluster-state-version":%d + }""", + bucketSpacesMetricJsonIfNonNull("buckets_total", hostInfoMetrics.bucketCount), + bucketSpacesMetricJsonIfNonNull("entries", hostInfoMetrics.entryCountOrNull), + bucketSpacesMetricJsonIfNonNull("docs", hostInfoMetrics.docCountOrNull), + hostInfoMetrics.bucketCount, clusterStateVersion)); } private List<ConfiguredNode> createNodes(int count) { diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index 21374061bfa..f5f9215d29d 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1517,7 +1517,6 @@ "abstract" ], "methods" : [ - "public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)", "public abstract com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId, com.yahoo.config.provision.ClusterSpec$Id)", "public static com.yahoo.config.model.api.OnnxModelCost disabled()" ], diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index eb5942bd49c..a6910c059fc 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -112,8 +112,8 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"bjorncs"}, removeAfter = "8.289") default boolean dynamicHeapSize() { return true; } @ModelFeatureFlag(owners = {"hmusum"}) default String unknownConfigDefinition() { return "warn"; } @ModelFeatureFlag(owners = {"hmusum"}) default int searchHandlerThreadpool() { return 2; } - @ModelFeatureFlag(owners = {"vekterli"}) default long mergingMaxMemoryUsagePerNode() { return -1; } - @ModelFeatureFlag(owners = {"vekterli"}) default boolean usePerDocumentThrottledDeleteBucket() { return false; } + @ModelFeatureFlag(owners = {"vekterli"}, removeAfter = "8.292.x") default long mergingMaxMemoryUsagePerNode() { return 0; } + @ModelFeatureFlag(owners = {"vekterli"}, removeAfter = "8.292.x") default boolean usePerDocumentThrottledDeleteBucket() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean alwaysMarkPhraseExpensive() { return false; } @ModelFeatureFlag(owners = {"hmusum"}) default boolean restartOnDeployWhenOnnxModelChanges() { return false; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean sortBlueprintsByCost() { return false; } diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index 1efd98184cc..49ef3cf4929 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ 
b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -15,10 +15,6 @@ import java.util.Map; */ public interface OnnxModelCost { - // TODO: Remove when no longer in use (oldest model version is 8.283) - default Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { - return newCalculator(appPkg, applicationId, null); - } Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId, ClusterSpec.Id clusterId); interface Calculator { diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index 6f0254145d6..b088231e84a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -84,8 +84,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private List<DataplaneToken> dataplaneTokens; private int contentLayerMetadataFeatureLevel = 0; private boolean dynamicHeapSize = false; - private long mergingMaxMemoryUsagePerNode = -1; - private boolean usePerDocumentThrottledDeleteBucket = false; private boolean restartOnDeployWhenOnnxModelChanges = false; @Override public ModelContext.FeatureFlags featureFlags() { return this; } @@ -144,8 +142,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public List<DataplaneToken> dataplaneTokens() { return dataplaneTokens; } @Override public int contentLayerMetadataFeatureLevel() { return contentLayerMetadataFeatureLevel; } @Override public boolean dynamicHeapSize() { return dynamicHeapSize; } - @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } - @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } public TestProperties sharedStringRepoNoReclaim(boolean sharedStringRepoNoReclaim) { @@ -379,16 +375,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea public TestProperties setDynamicHeapSize(boolean b) { this.dynamicHeapSize = b; return this; } - public TestProperties setMergingMaxMemoryUsagePerNode(long maxUsage) { - this.mergingMaxMemoryUsagePerNode = maxUsage; - return this; - } - - public TestProperties setUsePerDocumentThrottledDeleteBucket(boolean enableThrottling) { - this.usePerDocumentThrottledDeleteBucket = enableThrottling; - return this; - } - public TestProperties setRestartOnDeployForOnnxModelChanges(boolean enable) { this.restartOnDeployWhenOnnxModelChanges = enable; return this; diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 502b054f84e..9b3e236612a 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -615,18 +615,6 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addSummaryFeature(ReferenceNode feature) { - if (summaryFeatures == null) - summaryFeatures = new LinkedHashSet<>(); - summaryFeatures.add(feature); - } - - private void addMatchFeature(ReferenceNode feature) { - if (matchFeatures == null) - matchFeatures = new LinkedHashSet<>(); - matchFeatures.add(feature); - } - private void 
addImplicitMatchFeatures(List<FeatureList> list) { if (hiddenMatchFeatures == null) hiddenMatchFeatures = new LinkedHashSet<>(); @@ -642,15 +630,19 @@ public class RankProfile implements Cloneable { /** Adds the content of the given feature list to the internal list of summary features. */ public void addSummaryFeatures(FeatureList features) { + if (summaryFeatures == null) + summaryFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addSummaryFeature(feature); + summaryFeatures.add(feature); } } /** Adds the content of the given feature list to the internal list of match features. */ public void addMatchFeatures(FeatureList features) { + if (matchFeatures == null) + matchFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addMatchFeature(feature); + matchFeatures.add(feature); } } @@ -661,20 +653,16 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addRankFeature(ReferenceNode feature) { - if (rankFeatures == null) - rankFeatures = new LinkedHashSet<>(); - rankFeatures.add(feature); - } - /** * Adds the content of the given feature list to the internal list of rank features. * * @param features The features to add. */ public void addRankFeatures(FeatureList features) { + if (rankFeatures == null) + rankFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addRankFeature(feature); + rankFeatures.add(feature); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java index 348b84367d5..0d85696d503 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java @@ -160,7 +160,10 @@ public class MetricsProxyContainerCluster extends ContainerCluster<MetricsProxyC builder.consumer.add(toConsumerBuilder(MetricsConsumer.defaultConsumer)); builder.consumer.add(toConsumerBuilder(newDefaultConsumer())); - if (isHostedVespa()) builder.consumer.add(toConsumerBuilder(MetricsConsumer.vespa9)); + if (isHostedVespa()) { + var amendedVespa9Consumer = addMetrics(MetricsConsumer.vespa9, getAdditionalDefaultMetrics().getMetrics()); + builder.consumer.add(toConsumerBuilder(amendedVespa9Consumer)); + } getAdmin() .map(Admin::getAmendedMetricsConsumers) .map(consumers -> consumers.stream().map(ConsumersConfigGenerator::toConsumerBuilder).toList()) @@ -200,7 +203,7 @@ public class MetricsProxyContainerCluster extends ContainerCluster<MetricsProxyC private Optional<String> getSystemName() { Monitoring monitoring = getMonitoringService(); - return monitoring != null && ! monitoring.getClustername().equals("") ? + return monitoring != null && !monitoring.getClustername().isEmpty() ? 
Optional.of(monitoring.getClustername()) : Optional.empty(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index f2d9c0fcd1c..98adde7b547 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -63,18 +63,14 @@ public class RankSetupValidator implements Validator { context.deployState().getProperties().applicationId().toFullString() + ".") .toFile(); - for (SearchCluster cluster : context.model().getSearchClusters()) { - // Skipping ranking expression checking for streaming clusters, not implemented yet - if (cluster.isStreaming()) continue; - - IndexedSearchCluster sc = (IndexedSearchCluster) cluster; + for (SearchCluster sc : context.model().getSearchClusters()) { String clusterDir = cfgDir.getAbsolutePath() + "/" + sc.getClusterName() + "/"; for (DocumentDatabase docDb : sc.getDocumentDbs()) { String schemaName = docDb.getDerivedConfiguration().getSchema().getName(); String schemaDir = clusterDir + schemaName + "/"; writeConfigs(schemaDir, docDb); writeExtraVerifyRankSetupConfig(schemaDir, docDb); - if (!validate(context, "dir:" + schemaDir, sc, schemaName, cfgDir)) { + if (!validate(context, "dir:" + schemaDir, sc, schemaName, cfgDir, sc.isStreaming())) { return; } } @@ -87,11 +83,11 @@ public class RankSetupValidator implements Validator { } } - private boolean validate(Context context, String configId, SearchCluster searchCluster, String schema, File tempDir) { + private boolean validate(Context context, String configId, SearchCluster searchCluster, String schema, File tempDir, boolean isStreaming) { Instant start = Instant.now(); try { log.log(Level.FINE, () -> String.format("Validating schema '%s' for cluster %s with config id %s", schema, searchCluster, configId)); - boolean ret = execValidate(context, configId, searchCluster, schema); + boolean ret = execValidate(context, configId, searchCluster, schema, isStreaming); if (!ret) { // Give up, don't log same error msg repeatedly deleteTempDir(tempDir); @@ -175,8 +171,8 @@ public class RankSetupValidator implements Validator { IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } - private boolean execValidate(Context context, String configId, SearchCluster sc, String sdName) { - String command = String.format("%s %s", binaryName, configId); + private boolean execValidate(Context context, String configId, SearchCluster sc, String sdName, boolean isStreaming) { + String command = String.format((isStreaming ? 
"%s %s -S" : "%s %s"), binaryName, configId); try { Pair<Integer, String> ret = new ProcessExecuter(true).exec(command); Integer exitCode = ret.getFirst(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilder.java index 1a5041f44ac..973ebc8c602 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.builder.xml.dom; import ai.vespa.validation.Validation; +import com.yahoo.component.Version; import com.yahoo.config.model.ApplicationConfigProducerRoot; import com.yahoo.config.model.ConfigModelRepo; import com.yahoo.config.model.builder.xml.XmlHelper; @@ -62,6 +63,8 @@ public class VespaDomBuilder extends VespaModelBuilder { return new DomRootBuilder(name). build(deployState, parent, XmlHelper.getDocument(deployState.getApplicationPackage().getServices(), "services.xml") .getDocumentElement()); + } catch (IllegalArgumentException e) { + throw e; } catch (Exception e) { throw new IllegalArgumentException(e); } @@ -204,6 +207,7 @@ public class VespaDomBuilder extends VespaModelBuilder { @Override protected ApplicationConfigProducerRoot doBuild(DeployState deployState, TreeConfigProducer<AnyConfigProducer> parent, Element producerSpec) { + verifyMinimumRequiredVespaVersion(deployState.getVespaVersion(), producerSpec); ApplicationConfigProducerRoot root = new ApplicationConfigProducerRoot(parent, name, deployState.getDocumentModel(), @@ -215,6 +219,17 @@ public class VespaDomBuilder extends VespaModelBuilder { new Client(root); return root; } + + private static void verifyMinimumRequiredVespaVersion(Version thisVersion, Element producerSpec) { + var minimumRequiredVespaVersion = producerSpec.getAttribute("minimum-required-vespa-version"); + if (minimumRequiredVespaVersion.isEmpty()) return; + if (Version.fromString(minimumRequiredVespaVersion).compareTo(thisVersion) > 0) + throw new IllegalArgumentException( + ("Cannot deploy application, minimum required Vespa version is specified as %s in services.xml" + + ", this Vespa version is %s.") + .formatted(minimumRequiredVespaVersion, thisVersion.toFullString())); + } + } /** diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/Client.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/Client.java index 29222817d17..e4abef4eb33 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/Client.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/Client.java @@ -4,28 +4,36 @@ package com.yahoo.vespa.model.container.http; import com.yahoo.config.provision.DataplaneToken; import java.security.cert.X509Certificate; +import java.util.Collection; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.yahoo.vespa.model.container.http.Client.Permission.READ; +import static com.yahoo.vespa.model.container.http.Client.Permission.WRITE; /** * Represents a client. The client is identified by one of the provided certificates and have a set of permissions. 
* * @author mortent + * @author bjorncs */ public class Client { private final String id; - private final List<String> permissions; + private final Set<Permission> permissions; private final List<X509Certificate> certificates; private final List<DataplaneToken> tokens; private final boolean internal; - public Client(String id, List<String> permissions, List<X509Certificate> certificates, List<DataplaneToken> tokens) { + public Client(String id, Collection<Permission> permissions, List<X509Certificate> certificates, List<DataplaneToken> tokens) { this(id, permissions, certificates, tokens, false); } - private Client(String id, List<String> permissions, List<X509Certificate> certificates, List<DataplaneToken> tokens, + private Client(String id, Collection<Permission> permissions, List<X509Certificate> certificates, List<DataplaneToken> tokens, boolean internal) { this.id = id; - this.permissions = List.copyOf(permissions); + this.permissions = Set.copyOf(permissions); this.certificates = List.copyOf(certificates); this.tokens = List.copyOf(tokens); this.internal = internal; @@ -35,7 +43,7 @@ public class Client { return id; } - public List<String> permissions() { + public Set<Permission> permissions() { return permissions; } @@ -50,6 +58,29 @@ public class Client { } public static Client internalClient(List<X509Certificate> certificates) { - return new Client("_internal", List.of("read","write"), certificates, List.of(), true); + return new Client("_internal", Set.of(READ, WRITE), certificates, List.of(), true); + } + + public enum Permission { + READ, WRITE; + + public String asString() { + return switch (this) { + case READ -> "read"; + case WRITE -> "write"; + }; + } + + public static Permission fromString(String v) { + return switch (v) { + case "read" -> READ; + case "write" -> WRITE; + default -> throw new IllegalArgumentException("Invalid permission '%s'. Valid values are 'read' and 'write'.".formatted(v)); + }; + } + + public static Set<Permission> fromCommaSeparatedString(String str) { + return Stream.of(str.split(",")).map(v -> Permission.fromString(v.strip())).collect(Collectors.toSet()); + } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java index a1b569fa110..0574e13e387 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java @@ -46,7 +46,7 @@ class CloudDataPlaneFilter extends Filter implements CloudDataPlaneFilterConfig. 
.map(x -> new CloudDataPlaneFilterConfig.Clients.Builder() .id(x.id()) .certificates(x.certificates().stream().map(X509CertificateUtils::toPem).toList()) - .permissions(x.permissions())) + .permissions(x.permissions().stream().map(Client.Permission::asString).sorted().toList())) .toList(); builder.clients(clientsCfg).legacyMode(false); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java index bb24f96784e..e2a522103e6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java @@ -44,7 +44,7 @@ class CloudTokenDataPlaneFilter extends Filter implements CloudTokenDataPlaneFil .map(x -> new CloudTokenDataPlaneFilterConfig.Clients.Builder() .id(x.id()) .tokens(tokensConfig(x.tokens())) - .permissions(x.permissions())) + .permissions(x.permissions().stream().map(Client.Permission::asString).sorted().toList())) .toList(); builder.clients(clientsCfg).tokenContext(tokenContext); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index e4038a5bca6..8eca29215d4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -518,10 +518,8 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { String clientId = XML.attribute("id", clientElement).orElseThrow(); if (clientId.startsWith("_")) throw new IllegalArgumentException("Invalid client id '%s', id cannot start with '_'".formatted(clientId)); - List<String> permissions = XML.attribute("permissions", clientElement) - .map(p -> p.split(",")).stream() - .flatMap(Arrays::stream) - .toList(); + var permissions = XML.attribute("permissions", clientElement) + .map(Client.Permission::fromCommaSeparatedString).orElse(Set.of()); var certificates = XML.getChildren(clientElement, "certificate").stream() .flatMap(certElem -> { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java index 2be65a946b4..465e2397d00 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java @@ -127,7 +127,6 @@ public class DistributorCluster extends TreeConfigProducer<Distributor> implemen .selectiontoremove("not (" + gc.selection + ")") .interval(gc.interval)); } - builder.enable_revert(parent.getPersistence().supportRevert()); builder.disable_bucket_activation(!hasIndexedDocumentType); builder.max_activation_inhibited_out_of_sync_groups(maxActivationInhibitedOutOfSyncGroups); builder.enable_condition_probing(true); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/DummyPersistence.java b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/DummyPersistence.java index 3b89f26d275..61dd4380f1b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/DummyPersistence.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/DummyPersistence.java @@ -26,11 +26,6 @@ 
public class DummyPersistence extends PersistenceEngine { } @Override - public boolean supportRevert() { - return true; - } - - @Override public boolean enableMultiLevelSplitting() { return true; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/PersistenceEngine.java b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/PersistenceEngine.java index fb075c500cf..e80577184be 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/PersistenceEngine.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/PersistenceEngine.java @@ -24,14 +24,6 @@ public abstract class PersistenceEngine extends TreeConfigProducer<AnyConfigProd PersistenceEngine create(DeployState deployState, StorageNode storageNode, StorageGroup parentGroup, ModelElement storageNodeElement); /** - * If a write request succeeds on some nodes and fails on others, causing request to - * fail to client, the content layer will revert the operation where it succeeded if - * reverts are supported. (Typically require backend to keep multiple entries of the - * same document identifier persisted at the same time) - */ - boolean supportRevert(); - - /** * Multi level splitting can increase split performance a lot where documents have been * co-localized, for backends where retrieving document identifiers contained in bucket * is cheap. Backends where split is cheaper than fetching document identifiers will diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/ProtonEngine.java b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/ProtonEngine.java index c4e99649ede..860f2534736 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/engines/ProtonEngine.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/engines/ProtonEngine.java @@ -27,11 +27,6 @@ public class ProtonEngine { } @Override - public boolean supportRevert() { - return false; - } - - @Override public boolean enableMultiLevelSplitting() { return false; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java index 18b9129cead..4f283d6d9c3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java @@ -47,7 +47,6 @@ public class FileStorProducer implements StorFilestorConfig.Producer { private final int responseNumThreads; private final StorFilestorConfig.Response_sequencer_type.Enum responseSequencerType; private final boolean useAsyncMessageHandlingOnSchedule; - private final boolean usePerDocumentThrottledDeleteBucket; private static StorFilestorConfig.Response_sequencer_type.Enum convertResponseSequencerType(String sequencerType) { try { @@ -63,9 +62,9 @@ public class FileStorProducer implements StorFilestorConfig.Producer { this.responseNumThreads = featureFlags.defaultNumResponseThreads(); this.responseSequencerType = convertResponseSequencerType(featureFlags.responseSequencerType()); this.useAsyncMessageHandlingOnSchedule = featureFlags.useAsyncMessageHandlingOnSchedule(); - this.usePerDocumentThrottledDeleteBucket = featureFlags.usePerDocumentThrottledDeleteBucket(); } + @Override public void getConfig(StorFilestorConfig.Builder builder) { if (numThreads != null) { @@ -75,7 +74,6 @@ public class FileStorProducer implements 
StorFilestorConfig.Producer { builder.num_response_threads(responseNumThreads); builder.response_sequencer_type(responseSequencerType); builder.use_async_message_handling_on_schedule(useAsyncMessageHandlingOnSchedule); - builder.use_per_document_throttled_delete_bucket(usePerDocumentThrottledDeleteBucket); var throttleBuilder = new StorFilestorConfig.Async_operation_throttler.Builder(); builder.async_operation_throttler(throttleBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java index 1865db0ec1c..81fc19de929 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java @@ -29,7 +29,6 @@ public class StorServerProducer implements StorServerConfig.Producer { private final String clusterName; private Integer maxMergesPerNode; private Integer queueSize; - private Long mergingMaxMemoryUsagePerNode; private StorServerProducer setMaxMergesPerNode(Integer value) { if (value != null) { @@ -46,7 +45,6 @@ public class StorServerProducer implements StorServerConfig.Producer { StorServerProducer(String clusterName, ModelContext.FeatureFlags featureFlags) { this.clusterName = clusterName; - this.mergingMaxMemoryUsagePerNode = featureFlags.mergingMaxMemoryUsagePerNode(); } @Override @@ -63,10 +61,8 @@ public class StorServerProducer implements StorServerConfig.Producer { if (queueSize != null) { builder.max_merge_queue_size(queueSize); } - if (mergingMaxMemoryUsagePerNode != null) { - builder.merge_throttling_memory_limit( - new StorServerConfig.Merge_throttling_memory_limit.Builder() - .max_usage_bytes(mergingMaxMemoryUsagePerNode)); - } + builder.merge_throttling_memory_limit( + new StorServerConfig.Merge_throttling_memory_limit.Builder() + .max_usage_bytes(0)); } } diff --git a/config-model/src/main/resources/schema/services.rnc b/config-model/src/main/resources/schema/services.rnc index 1c30b2d91f9..03d4ee80683 100644 --- a/config-model/src/main/resources/schema/services.rnc +++ b/config-model/src/main/resources/schema/services.rnc @@ -7,6 +7,7 @@ include "containercluster.rnc" start = element services { attribute version { "1.0" }? & + attribute minimum-required-vespa-version { text }? & attribute application-type { "hosted-infrastructure" }? & element legacy { element v7-geo-positions { xsd:boolean } }? 
& GenericConfig* & diff --git a/config-model/src/test/derived/rankprofileinheritance/child.sd b/config-model/src/test/derived/rankprofileinheritance/child.sd index 2517d0731f5..8348a62838c 100644 --- a/config-model/src/test/derived/rankprofileinheritance/child.sd +++ b/config-model/src/test/derived/rankprofileinheritance/child.sd @@ -39,4 +39,15 @@ schema child { } + rank-profile profile5 inherits profile1 { + match-features { + attribute(field3) + } + } + + rank-profile profile6 inherits profile1 { + summary-features { } + match-features { } + } + } diff --git a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg index a3bc6791412..ccf52da3b5e 100644 --- a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg +++ b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg @@ -52,3 +52,23 @@ rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "rankingExpression(function4)" rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "function4" +rankprofile[].name "profile5" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "attribute(field1)" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(field3)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "function1" +rankprofile[].name "profile6" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" diff --git a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java index c959634019d..e920672646f 100644 --- a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java @@ -1,11 +1,17 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.schema; +import com.yahoo.document.Document; import com.yahoo.schema.document.Stemming; import com.yahoo.schema.parser.ParseException; import com.yahoo.schema.processing.ImportedFieldsResolver; import com.yahoo.schema.processing.OnnxModelTypeResolver; import com.yahoo.vespa.documentmodel.DocumentSummary; +import com.yahoo.vespa.indexinglanguage.expressions.AttributeExpression; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; +import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression; +import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.model.test.utils.DeployLoggerStub; import org.junit.jupiter.api.Test; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java index becb7235c64..88e1ba7a1a6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java @@ -23,6 +23,7 @@ import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.g import static com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyModelTester.servicesWithAdminOnly; import static java.util.Collections.singleton; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -69,15 +70,21 @@ public class MetricsConsumersTest { @Test void vespa_consumer_can_be_amended_via_admin_object() { - VespaModel model = getModel(servicesWithAdminOnly(), self_hosted); + VespaModel model = getModel(servicesWithAdminOnly(), hosted); var additionalMetric = new Metric("additional-metric"); model.getAdmin().setAdditionalDefaultMetrics(new MetricSet("amender-metrics", singleton(additionalMetric))); ConsumersConfig config = consumersConfigFromModel(model); assertEquals(numMetricsForVespaConsumer + 1, config.consumer(0).metric().size()); - ConsumersConfig.Consumer vespaConsumer = config.consumer(0); + ConsumersConfig.Consumer vespaConsumer = requireConsumer(config, MetricsConsumer.vespa); assertTrue(checkMetric(vespaConsumer, additionalMetric), "Did not contain additional metric"); + + ConsumersConfig.Consumer defaultConsumer = requireConsumer(config, MetricsConsumer.defaultConsumer); + assertFalse(checkMetric(defaultConsumer, additionalMetric), "Contained additional metric"); + + ConsumersConfig.Consumer vespa9Consumer = requireConsumer(config, MetricsConsumer.vespa9); + assertTrue(checkMetric(vespa9Consumer, additionalMetric), "Did not contain additional metric"); } @Test @@ -249,4 +256,11 @@ public class MetricsConsumersTest { assertTrue(checkMetric(consumer, customMetric), "Did not contain metric: " + customMetric); } + private ConsumersConfig.Consumer requireConsumer(ConsumersConfig config, MetricsConsumer consumer) { + return config.consumer() + .stream() + .filter(c -> c.name().equals(consumer.id())) + .findFirst().orElseThrow(); + } + } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index a53ef233746..52d861ac902 100644 --- 
a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -29,8 +29,6 @@ import java.util.concurrent.atomic.AtomicLong; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author bjorncs @@ -119,7 +117,6 @@ class JvmHeapSizeValidatorTest { ModelCostDummy(long modelCost) { this.modelCost = modelCost; } - @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId, ClusterSpec.Id clusterId) { return this; } @Override public Map<String, ModelInfo> models() { return Map.of(); } @Override public void setRestartOnDeploy() {} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/StreamingValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/StreamingValidatorTest.java index 6f66838ba47..5397c30f2bc 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/StreamingValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/StreamingValidatorTest.java @@ -33,12 +33,16 @@ public class StreamingValidatorTest { "Document references and imported fields are not allowed in streaming search.")); } + private static List<String> filter(List<String> warnings) { + return warnings.stream().filter(x -> x.indexOf("Cannot run program") == -1).toList(); + } + @Test void tensor_field_without_index_gives_no_warning() { var logger = new TestableDeployLogger(); var model = createModel(logger, "field nn type tensor(x[2]) { indexing: attribute | summary\n" + "attribute { distance-metric: euclidean } }"); - assertTrue(logger.warnings.isEmpty()); + assertTrue(filter(logger.warnings).isEmpty()); } @Test @@ -46,9 +50,10 @@ public class StreamingValidatorTest { var logger = new TestableDeployLogger(); var model = createModel(logger, "field nn type tensor(x[2]) { indexing: attribute | index | summary\n" + "attribute { distance-metric: euclidean } }"); - assertEquals(1, logger.warnings.size()); + var warnings = filter(logger.warnings); + assertEquals(1, warnings.size()); assertEquals("For streaming search cluster 'content.test', SD field 'nn': hnsw index is not relevant and not supported, ignoring setting", - logger.warnings.get(0)); + warnings.get(0)); } private static VespaModel createModel(DeployLogger logger, String sdContent) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java index 66a64681c60..2d5b1a307cd 100755 --- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/VespaDomBuilderTest.java @@ -17,6 +17,7 @@ import java.io.StringReader; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * @author gjoranv @@ -99,6 +100,20 @@ public class VespaDomBuilderTest { assertEquals("hosts [" + 
host.getHostname() + "]", hostSystem.toString()); } + @Test + void testMinimumRequiredVespaVersion() { + var exception = assertThrows(IllegalArgumentException.class, + () -> createModel(hosts, """ + <services minimum-required-vespa-version='1.0.1' > + </services>""")); + assertEquals("Cannot deploy application, minimum required Vespa version is specified as 1.0.1 in services.xml, this Vespa version is 1.0.0.", + exception.getMessage()); + + createModel(hosts, """ + <services minimum-required-vespa-version='1.0.0' > + </services>"""); + } + private VespaModel createModel(String hosts, String services) { VespaModelCreatorWithMockPkg creator = new VespaModelCreatorWithMockPkg(hosts, services); return creator.create(); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java index c89ea421b39..1c5eb16be80 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java @@ -16,7 +16,6 @@ import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; import com.yahoo.jdisc.http.ConnectorConfig; import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig; -import com.yahoo.processing.response.Data; import com.yahoo.vespa.model.container.ApplicationContainer; import com.yahoo.vespa.model.container.ContainerModel; import com.yahoo.vespa.model.container.http.ConnectorFactory; @@ -41,14 +40,14 @@ import static com.yahoo.vespa.model.container.xml.CloudDataPlaneFilterTest.creat import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; public class CloudTokenDataPlaneFilterTest extends ContainerModelBuilderTestBase { private static final String servicesXmlTemplate = """ <container version='1.0'> <clients> - <client id="foo" permissions="read,write"> + <client id="foo" permissions="read, write"> <certificate file="%s"/> </client> <client id="bar" permissions="read"> @@ -145,6 +144,24 @@ public class CloudTokenDataPlaneFilterTest extends ContainerModelBuilderTestBase } + @Test + void fails_on_unknown_permission() throws IOException { + var certFile = securityFolder.resolve("foo.pem"); + var servicesXml = """ + <container version='1.0'> + <clients> + <client id="foo" permissions="read,unknown-permission"> + <certificate file="%s"/> + </client> + </clients> + </container> + """.formatted(applicationFolder.toPath().relativize(certFile).toString()); + var clusterElem = DomBuilderTest.parse(servicesXml); + createCertificate(certFile); + var exception = assertThrows(IllegalArgumentException.class, () -> buildModel(Set.of(mtlsEndpoint), defaultTokens, clusterElem)); + assertEquals("Invalid permission 'unknown-permission'. 
Valid values are 'read' and 'write'.", exception.getMessage()); + } + private static CloudTokenDataPlaneFilterConfig.Clients.Tokens tokenConfig( String id, Collection<String> fingerprints, Collection<String> accessCheckHashes, Collection<String> expirations) { return new CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder() diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/DistributorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/DistributorTest.java index 0e692a7ff07..b54aefa04bc 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/DistributorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/DistributorTest.java @@ -52,20 +52,6 @@ public class DistributorTest { } @Test - void testRevertDefaultOffForSearch() { - StorDistributormanagerConfig.Builder builder = new StorDistributormanagerConfig.Builder(); - parse("<cluster id=\"storage\">\n" + - " <redundancy>3</redundancy>" + - " <documents/>" + - " <group>" + - " <node distribution-key=\"0\" hostalias=\"mockhost\"/>" + - " </group>" + - "</cluster>").getConfig(builder); - StorDistributormanagerConfig conf = new StorDistributormanagerConfig(builder); - assertFalse(conf.enable_revert()); - } - - @Test void testSplitAndJoin() { StorDistributormanagerConfig.Builder builder = new StorDistributormanagerConfig.Builder(); parse("<cluster id=\"storage\">\n" + diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java index bdd61d93136..7c3e66aa109 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java @@ -189,7 +189,7 @@ public class StorageClusterTest { var config = configFromProperties(new TestProperties()); var limit = config.merge_throttling_memory_limit(); - assertEquals(-1L, limit.max_usage_bytes()); // TODO change default + assertEquals(0L, limit.max_usage_bytes()); assertMergeAutoScaleConfigHasExpectedValues(limit); } @@ -200,21 +200,6 @@ public class StorageClusterTest { } @Test - void merge_throttler_memory_limit_is_controlled_by_feature_flag() { - var config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(-1)); - assertEquals(-1L, config.merge_throttling_memory_limit().max_usage_bytes()); - - config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(0)); - assertEquals(0L, config.merge_throttling_memory_limit().max_usage_bytes()); - - config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(1_234_456_789)); - assertEquals(1_234_456_789L, config.merge_throttling_memory_limit().max_usage_bytes()); - - // Feature flag should not affect the other config values - assertMergeAutoScaleConfigHasExpectedValues(config.merge_throttling_memory_limit()); - } - - @Test void testVisitors() { StorVisitorConfig.Builder builder = new StorVisitorConfig.Builder(); parse(cluster("bees", @@ -355,24 +340,6 @@ public class StorageClusterTest { assertTrue(config.async_operation_throttler().throttle_individual_merge_feed_ops()); } - private void verifyUsePerDocumentThrottledDeleteBucket(boolean expected, Boolean enabled) { - var props = new TestProperties(); - if (enabled != null) { - props.setUsePerDocumentThrottledDeleteBucket(enabled); - } - var config = filestorConfigFromProducer(simpleCluster(props)); - assertEquals(expected, 
config.use_per_document_throttled_delete_bucket()); - } - - @Test - void delete_bucket_throttling_is_controlled_by_feature_flag() { - // TODO update default once rolled out and tested - verifyUsePerDocumentThrottledDeleteBucket(false, null); - - verifyUsePerDocumentThrottledDeleteBucket(false, false); - verifyUsePerDocumentThrottledDeleteBucket(true, true); - } - @Test void testCapacity() { String xml = joinLines( diff --git a/config-model/src/test/schema-test-files/services-hosted-infrastructure.xml b/config-model/src/test/schema-test-files/services-hosted-infrastructure.xml index b1711906086..9144b1ad0f8 100644 --- a/config-model/src/test/schema-test-files/services-hosted-infrastructure.xml +++ b/config-model/src/test/schema-test-files/services-hosted-infrastructure.xml @@ -1,6 +1,6 @@ <?xml version="1.0" encoding="utf-8" ?> <!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> -<services version="1.0" application-type="hosted-infrastructure"> +<services version="1.0" application-type="hosted-infrastructure" minimum-required-vespa-version="8.0.0"> <admin version="4.0"> <slobroks><nodes count="3" flavor="small"/></slobroks> diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/CloudAccount.java b/config-provisioning/src/main/java/com/yahoo/config/provision/CloudAccount.java index 88583fc2007..5111cfb2704 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/CloudAccount.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/CloudAccount.java @@ -18,6 +18,7 @@ public class CloudAccount implements Comparable<CloudAccount> { } private static final Map<String, CloudMeta> META_BY_CLOUD = Map.of( "aws", new CloudMeta("Account ID", Pattern.compile("[0-9]{12}")), + "azure", new CloudMeta("Subscription ID", Pattern.compile("[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}")), "gcp", new CloudMeta("Project ID", Pattern.compile("[a-z][a-z0-9-]{4,28}[a-z0-9]"))); /** Empty value. 
When this is used, either implicitly or explicitly, the zone will use its default account */ diff --git a/config-provisioning/src/test/java/com/yahoo/config/provision/CloudAccountTest.java b/config-provisioning/src/test/java/com/yahoo/config/provision/CloudAccountTest.java index c9230feaa6d..5af9cdb9263 100644 --- a/config-provisioning/src/test/java/com/yahoo/config/provision/CloudAccountTest.java +++ b/config-provisioning/src/test/java/com/yahoo/config/provision/CloudAccountTest.java @@ -46,6 +46,14 @@ class CloudAccountTest { } @Test + void azure_account() { + CloudAccount account = CloudAccount.from("azure:248ace13-1234-abcd-89ad-123456789abc"); + assertEquals("248ace13-1234-abcd-89ad-123456789abc", account.account()); + assertEquals(CloudName.AZURE, account.cloudName()); + assertEquals("azure:248ace13-1234-abcd-89ad-123456789abc", account.value()); + } + + @Test void default_accounts() { CloudAccount variant1 = CloudAccount.from(""); CloudAccount variant2 = CloudAccount.from("default"); @@ -65,7 +73,7 @@ class CloudAccountTest { assertInvalidAccount("aws:123", "Invalid cloud account 'aws:123': Account ID must match '[0-9]{12}'"); assertInvalidAccount("gcp:123", "Invalid cloud account 'gcp:123': Project ID must match '[a-z][a-z0-9-]{4,28}[a-z0-9]'"); assertInvalidAccount("$something", "Invalid cloud account '$something': Must be on format '<cloud-name>:<account>' or 'default'"); - assertInvalidAccount("unknown:account", "Invalid cloud account 'unknown:account': Cloud name must be one of: aws, gcp"); + assertInvalidAccount("unknown:account", "Invalid cloud account 'unknown:account': Cloud name must be one of: aws, azure, gcp"); } private static void assertInvalidAccount(String account, String message) { diff --git a/configdefinitions/src/vespa/stor-filestor.def b/configdefinitions/src/vespa/stor-filestor.def index 090f74dec12..cefce5fc648 100644 --- a/configdefinitions/src/vespa/stor-filestor.def +++ b/configdefinitions/src/vespa/stor-filestor.def @@ -1,14 +1,6 @@ # Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=vespa.config.content -## DETECT FAILURE PARAMETERS - -## Deprecated and unused - will soon go away -fail_disk_after_error_count int default=1 restart - -## Deprecated and unused - will soon go away -disk_operation_timeout int default=0 restart - ## PERFORMANCE PARAMETERS ## Number of threads to use for each mountpoint. @@ -42,15 +34,6 @@ common_merge_chain_optimalization_minimum_size int default=64 restart ## Should follow stor-distributormanager:splitsize (16MB). bucket_merge_chunk_size int default=16772216 restart -## When merging, it is possible to send more metadata than needed in order to -## let local nodes in merge decide which entries fits best to add this time -## based on disk location. Toggle this option on to use it. Note that memory -## consumption might increase in a 4.1 to 4.2 upgrade due to this, as 4.1 -## dont support to only fill in part of the metadata provided and will always -## fill all. -## NB unused and will be removed shortly. -enable_merge_local_node_choose_docs_optimalization bool default=true restart - ## Whether or not to enable the multibit split optimalization. This is useful ## if splitting is expensive, but listing document identifiers is fairly cheap. ## This is true for memfile persistence layer, but not for vespa search. @@ -93,32 +76,3 @@ async_operation_throttler.resize_rate double default=3.0 ## level, i.e. 
per ApplyBucketDiff message, regardless of how many document operations ## are contained within. async_operation_throttler.throttle_individual_merge_feed_ops bool default=true - -## Specify throttling used for async persistence operations. This throttling takes place -## before operations are dispatched to Proton and serves as a limiter for how many -## operations may be in flight in Proton's internal queues. -## -## - UNLIMITED is, as it says on the tin, unlimited. Offers no actual throttling, but -## has near zero overhead and never blocks. -## - DYNAMIC uses DynamicThrottlePolicy under the hood and will block if the window -## is full (if a blocking throttler API call is invoked). -## -## TODO deprecate in favor of the async_operation_throttler struct instead. -async_operation_throttler_type enum { UNLIMITED, DYNAMIC } default=DYNAMIC - -## Specifies the extent the throttling window is increased by when the async throttle -## policy has decided that more concurrent operations are desirable. Also affects the -## _minimum_ size of the throttling window; its size is implicitly set to max(this config -## value, number of threads). -## -## Only applies if async_operation_throttler_type == DYNAMIC. -## DEPRECATED! use the async_operation_throttler struct instead -async_operation_dynamic_throttling_window_increment int default=20 restart - -## If set, DeleteBucket operations are internally expanded to an individually persistence- -## throttled remove per document stored in the bucket. This makes the cost model of -## executing a DeleteBucket symmetrical with feeding the documents to the bucket in the -## first place. -## -## This is a live config. -use_per_document_throttled_delete_bucket bool default=false diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index 64e5f80d72c..32f4d2b653c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -15,6 +15,7 @@ import com.yahoo.config.model.api.HostInfo; import com.yahoo.config.model.api.ServiceInfo; import com.yahoo.config.provision.ActivationContext; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationLockException; import com.yahoo.config.provision.ApplicationTransaction; import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.EndpointsChecker; @@ -24,6 +25,7 @@ import com.yahoo.config.provision.EndpointsChecker.HealthCheckerProvider; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.HostFilter; import com.yahoo.config.provision.InfraDeployer; +import com.yahoo.config.provision.ParentHostUnavailableException; import com.yahoo.config.provision.Provisioner; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; @@ -41,6 +43,7 @@ import com.yahoo.slime.Slime; import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.applicationmodel.InfrastructureApplication; +import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints.Token; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprintsClient; import com.yahoo.vespa.config.server.application.Application; @@ -52,7 +55,6 @@ import 
com.yahoo.vespa.config.server.application.ClusterReindexing; import com.yahoo.vespa.config.server.application.ClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.CompressedApplicationInputStream; import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker; -import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.DefaultClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.FileDistributionStatus; import com.yahoo.vespa.config.server.application.HttpProxy; @@ -71,6 +73,7 @@ import com.yahoo.vespa.config.server.http.LogRetriever; import com.yahoo.vespa.config.server.http.SecretStoreValidator; import com.yahoo.vespa.config.server.http.SimpleHttpFetcher; import com.yahoo.vespa.config.server.http.TesterClient; +import com.yahoo.vespa.config.server.http.v2.PrepareAndActivateResult; import com.yahoo.vespa.config.server.http.v2.PrepareResult; import com.yahoo.vespa.config.server.http.v2.response.DeploymentMetricsResponse; import com.yahoo.vespa.config.server.http.v2.response.SearchNodeMetricsResponse; @@ -108,7 +111,6 @@ import java.time.Duration; import java.time.Instant; import java.util.Collection; import java.util.Comparator; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -132,7 +134,6 @@ import static com.yahoo.vespa.config.server.tenant.TenantRepository.HOSTED_VESPA import static com.yahoo.vespa.curator.Curator.CompletionWaiter; import static com.yahoo.yolean.Exceptions.uncheck; import static java.nio.file.Files.readAttributes; -import static java.util.stream.Collectors.toMap; /** * The API for managing applications. @@ -363,36 +364,40 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye return deployment; } - public PrepareResult deploy(CompressedApplicationInputStream in, PrepareParams prepareParams) { + public PrepareAndActivateResult deploy(CompressedApplicationInputStream in, PrepareParams prepareParams) { DeployHandlerLogger logger = DeployHandlerLogger.forPrepareParams(prepareParams); File tempDir = uncheck(() -> Files.createTempDirectory("deploy")).toFile(); ThreadLockStats threadLockStats = LockStats.getForCurrentThread(); - PrepareResult prepareResult; + PrepareAndActivateResult result; try { threadLockStats.startRecording("deploy of " + prepareParams.getApplicationId().serializedForm()); - prepareResult = deploy(decompressApplication(in, tempDir), prepareParams, logger); + result = deploy(decompressApplication(in, tempDir), prepareParams, logger); } finally { threadLockStats.stopRecording(); cleanupTempDirectory(tempDir, logger); } - return prepareResult; + return result; } public PrepareResult deploy(File applicationPackage, PrepareParams prepareParams) { - return deploy(applicationPackage, prepareParams, DeployHandlerLogger.forPrepareParams(prepareParams)); + return deploy(applicationPackage, prepareParams, DeployHandlerLogger.forPrepareParams(prepareParams)).deployResult(); } - private PrepareResult deploy(File applicationDir, PrepareParams prepareParams, DeployHandlerLogger logger) { + private PrepareAndActivateResult deploy(File applicationDir, PrepareParams prepareParams, DeployHandlerLogger logger) { long sessionId = createSession(prepareParams.getApplicationId(), prepareParams.getTimeoutBudget(), applicationDir, logger); Deployment deployment = prepare(sessionId, prepareParams, logger); - if ( ! 
prepareParams.isDryRun()) + RuntimeException activationFailure = null; + if ( ! prepareParams.isDryRun()) try { deployment.activate(); - - return new PrepareResult(sessionId, deployment.configChangeActions(), logger); + } + catch (ParentHostUnavailableException | ApplicationLockException e) { + activationFailure = e; + } + return new PrepareAndActivateResult(new PrepareResult(sessionId, deployment.configChangeActions(), logger), activationFailure); } /** @@ -537,7 +542,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye static void checkIfActiveHasChanged(Session session, Session activeSession, boolean ignoreStaleSessionFailure) { long activeSessionAtCreate = session.getActiveSessionAtCreate(); log.log(Level.FINE, () -> activeSession.logPre() + "active session id at create time=" + activeSessionAtCreate); - if (activeSessionAtCreate == 0) return; // No active session at create time + if (activeSessionAtCreate == 0) return; // No active session at create time, or session created for indeterminate app. long sessionId = session.getSessionId(); long activeSessionSessionId = activeSession.getSessionId(); @@ -545,10 +550,10 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye ", current active session=" + activeSessionSessionId); if (activeSession.isNewerThan(activeSessionAtCreate) && activeSessionSessionId != sessionId) { - String errMsg = activeSession.logPre() + "Cannot activate session " + - sessionId + " because the currently active session (" + - activeSessionSessionId + ") has changed since session " + sessionId + - " was created (was " + activeSessionAtCreate + " at creation time)"; + String errMsg = activeSession.logPre() + "Cannot activate session " + sessionId + + " because the currently active session (" + activeSessionSessionId + + ") has changed since session " + sessionId + " was created (was " + + activeSessionAtCreate + " at creation time)"; if (ignoreStaleSessionFailure) { log.warning(errMsg + " (Continuing because of force.)"); } else { @@ -950,30 +955,25 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } public void deleteExpiredLocalSessions() { - Map<Tenant, Collection<LocalSession>> sessionsPerTenant = new HashMap<>(); - tenantRepository.getAllTenants() - .forEach(tenant -> sessionsPerTenant.put(tenant, tenant.getSessionRepository().getLocalSessions())); - - Set<ApplicationId> applicationIds = new HashSet<>(); - sessionsPerTenant.values() - .forEach(sessionList -> sessionList.stream() - .map(Session::getOptionalApplicationId) - .filter(Optional::isPresent) - .forEach(appId -> applicationIds.add(appId.get()))); - - Map<ApplicationId, Long> activeSessions = new HashMap<>(); - applicationIds.forEach(applicationId -> getActiveSession(applicationId).ifPresent(session -> activeSessions.put(applicationId, session.getSessionId()))); - sessionsPerTenant.keySet().forEach(tenant -> tenant.getSessionRepository().deleteExpiredSessions(activeSessions)); + for (Tenant tenant : tenantRepository.getAllTenants()) { + tenant.getSessionRepository().deleteExpiredSessions(session -> sessionIsActiveForItsApplication(tenant, session)); + } } public int deleteExpiredRemoteSessions(Clock clock) { return tenantRepository.getAllTenants() .stream() - .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock)) + .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock, session -> sessionIsActiveForItsApplication(tenant, session))) .mapToInt(i -> i) .sum(); } + 
private boolean sessionIsActiveForItsApplication(Tenant tenant, Session session) { + Optional<ApplicationId> owner = session.getOptionalApplicationId(); + if (owner.isEmpty()) return true; // Chicken out ~(˘▾˘)~ + return tenant.getApplicationRepo().activeSessionOf(owner.get()).equals(Optional.of(session.getSessionId())); + } + // ---------------- Tenant operations ---------------------------------------------------------------- diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationMapper.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationMapper.java index de86e9a9cdc..6b1a75f2f44 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationMapper.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationMapper.java @@ -80,7 +80,8 @@ public final class ApplicationMapper { } public List<Application> listApplications(ApplicationId applicationId) { - return requestHandlers.get(applicationId).applications(); + var applicationVersions = requestHandlers.get(applicationId); + return applicationVersions == null ? List.of() : applicationVersions.applications(); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java index 7c7608e5e5c..c219d59a44e 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java @@ -177,7 +177,7 @@ public class Deployment implements com.yahoo.config.provision.Deployment { nodesToRestart.size(), nodesToRestart.stream().sorted().collect(joining(", ")))); log.info(String.format("%sWill schedule service restart of %d nodes after convergence on generation %d: %s", session.logPre(), nodesToRestart.size(), session.getSessionId(), nodesToRestart.stream().sorted().collect(joining(", ")))); - this.configChangeActions = configChangeActions.withRestartActions(new RestartActions()); + configChangeActions = configChangeActions == null ? 
null : configChangeActions.withRestartActions(new RestartActions()); } private void storeReindexing(ApplicationId applicationId) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 22b2b581b44..d500e56d079 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -201,13 +201,11 @@ public class ModelContextImpl implements ModelContext { private final int rpc_events_before_wakeup; private final int heapPercentage; private final String summaryDecodePolicy; - private boolean sortBlueprintsByCost; + private final boolean sortBlueprintsByCost; private final boolean alwaysMarkPhraseExpensive; private final int contentLayerMetadataFeatureLevel; private final String unknownConfigDefinition; private final int searchHandlerThreadpool; - private final long mergingMaxMemoryUsagePerNode; - private final boolean usePerDocumentThrottledDeleteBucket; private final boolean restartOnDeployWhenOnnxModelChanges; public FeatureFlags(FlagSource source, ApplicationId appId, Version version) { @@ -249,8 +247,6 @@ public class ModelContextImpl implements ModelContext { this.contentLayerMetadataFeatureLevel = flagValue(source, appId, version, Flags.CONTENT_LAYER_METADATA_FEATURE_LEVEL); this.unknownConfigDefinition = flagValue(source, appId, version, Flags.UNKNOWN_CONFIG_DEFINITION); this.searchHandlerThreadpool = flagValue(source, appId, version, Flags.SEARCH_HANDLER_THREADPOOL); - this.mergingMaxMemoryUsagePerNode = flagValue(source, appId, version, Flags.MERGING_MAX_MEMORY_USAGE_PER_NODE); - this.usePerDocumentThrottledDeleteBucket = flagValue(source, appId, version, Flags.USE_PER_DOCUMENT_THROTTLED_DELETE_BUCKET); this.alwaysMarkPhraseExpensive = flagValue(source, appId, version, Flags.ALWAYS_MARK_PHRASE_EXPENSIVE); this.restartOnDeployWhenOnnxModelChanges = flagValue(source, appId, version, Flags.RESTART_ON_DEPLOY_WHEN_ONNX_MODEL_CHANGES); this.sortBlueprintsByCost = flagValue(source, appId, version, Flags.SORT_BLUEPRINTS_BY_COST); @@ -303,8 +299,6 @@ public class ModelContextImpl implements ModelContext { @Override public int contentLayerMetadataFeatureLevel() { return contentLayerMetadataFeatureLevel; } @Override public String unknownConfigDefinition() { return unknownConfigDefinition; } @Override public int searchHandlerThreadpool() { return searchHandlerThreadpool; } - @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } - @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } @Override public boolean sortBlueprintsByCost() { return sortBlueprintsByCost; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java index 29f2125ac3c..737fb787937 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java @@ -3,15 +3,22 @@ package com.yahoo.vespa.config.server.http.v2; import com.yahoo.cloud.config.ConfigserverConfig; import 
com.yahoo.component.annotation.Inject; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationLockException; +import com.yahoo.config.provision.ParentHostUnavailableException; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; +import com.yahoo.config.provision.zone.ZoneId; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.utils.MultiPartFormParser; import com.yahoo.container.jdisc.utils.MultiPartFormParser.PartItem; import com.yahoo.jdisc.application.BindingMatch; import com.yahoo.jdisc.http.HttpHeaders; +import com.yahoo.restapi.MessageResponse; +import com.yahoo.restapi.SlimeJsonResponse; import com.yahoo.vespa.config.server.ApplicationRepository; +import com.yahoo.vespa.config.server.TimeoutBudget; import com.yahoo.vespa.config.server.application.CompressedApplicationInputStream; import com.yahoo.vespa.config.server.http.BadRequestException; import com.yahoo.vespa.config.server.http.SessionHandler; @@ -56,8 +63,8 @@ public class ApplicationApiHandler extends SessionHandler { private final TenantRepository tenantRepository; private final Duration zookeeperBarrierTimeout; - private final Zone zone; private final long maxApplicationPackageSize; + private final Zone zone; @Inject public ApplicationApiHandler(Context ctx, @@ -72,6 +79,17 @@ public class ApplicationApiHandler extends SessionHandler { } @Override + protected HttpResponse handlePUT(HttpRequest request) { + TenantName tenantName = validateTenant(request); + long sessionId = getSessionIdFromRequest(request); + ApplicationId app = applicationRepository.activate(tenantRepository.getTenant(tenantName), + sessionId, + getTimeoutBudget(request, Duration.ofMinutes(2)), + shouldIgnoreSessionStaleFailure(request)); + return new MessageResponse("Session " + sessionId + " for " + app.toFullString() + " activated"); + } + + @Override protected HttpResponse handlePOST(HttpRequest request) { validateDataAndHeader(request, List.of(APPLICATION_X_GZIP, APPLICATION_ZIP, MULTIPART_FORM_DATA)); TenantName tenantName = validateTenant(request); @@ -112,8 +130,8 @@ public class ApplicationApiHandler extends SessionHandler { .ifPresent(e -> e.addKeyValue("app.id", prepareParams.getApplicationId().toFullString())); try (compressedStream) { - PrepareResult result = applicationRepository.deploy(compressedStream, prepareParams); - return new SessionPrepareAndActivateResponse(result, request, prepareParams.getApplicationId(), zone); + PrepareAndActivateResult result = applicationRepository.deploy(compressedStream, prepareParams); + return new SessionPrepareAndActivateResponse(result, prepareParams.getApplicationId(), request, zone); } catch (IOException e) { throw new UncheckedIOException(e); @@ -132,8 +150,18 @@ public class ApplicationApiHandler extends SessionHandler { } public static TenantName getTenantNameFromRequest(HttpRequest request) { - BindingMatch<?> bm = Utils.getBindingMatch(request, "http://*/application/v2/tenant/*/prepareandactivate*"); + BindingMatch<?> bm = Utils.getBindingMatch(request, "http://*/application/v2/tenant/*/prepareandactivate*"); // Gosh, these glob rules aren't good ... 
return TenantName.from(bm.group(2)); } + public static long getSessionIdFromRequest(HttpRequest request) { + BindingMatch<?> bm = Utils.getBindingMatch(request, "http://*/application/v2/tenant/*/prepareandactivate/*"); + try { + return Long.parseLong(bm.group(3)); + } + catch (NumberFormatException e) { + throw new BadRequestException("Session id '" + bm.group(3) + "' is not a number: " + e.getMessage()); + } + } + } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/PrepareAndActivateResult.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/PrepareAndActivateResult.java new file mode 100644 index 00000000000..3da3a6752cd --- /dev/null +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/PrepareAndActivateResult.java @@ -0,0 +1,21 @@ +package com.yahoo.vespa.config.server.http.v2; + +import com.yahoo.config.provision.ParentHostUnavailableException; + +/** + * Allows a partial deployment success, where the application is prepared, but not activated. + * This currently only allows the parent-host-not-ready and application-lock cases, as other transient errors are + * thrown too early (LB during prepare, cert during validation), but could be expanded to allow + * reuse of a prepared session in the future. In that case, users of this result (handler and its client) + * must also be updated. + * + * @author jonmv + */ +public record PrepareAndActivateResult(PrepareResult prepareResult, RuntimeException activationFailure) { + + public PrepareResult deployResult() { + if (activationFailure != null) throw activationFailure; + return prepareResult; + } + +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/response/SessionPrepareAndActivateResponse.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/response/SessionPrepareAndActivateResponse.java index 1e6f7dfe45e..f5d8efbe4e5 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/response/SessionPrepareAndActivateResponse.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/response/SessionPrepareAndActivateResponse.java @@ -8,7 +8,7 @@ import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.restapi.SlimeJsonResponse; import com.yahoo.slime.Cursor; import com.yahoo.vespa.config.server.configchange.ConfigChangeActionsSlimeConverter; -import com.yahoo.vespa.config.server.http.v2.PrepareResult; +import com.yahoo.vespa.config.server.http.v2.PrepareAndActivateResult; /** * Creates a response for ApplicationApiHandler. @@ -17,17 +17,17 @@ import com.yahoo.vespa.config.server.http.v2.PrepareResult; */ public class SessionPrepareAndActivateResponse extends SlimeJsonResponse { - public SessionPrepareAndActivateResponse(PrepareResult result, HttpRequest request, ApplicationId applicationId, Zone zone) { - super(result.deployLogger().slime()); + public SessionPrepareAndActivateResponse(PrepareAndActivateResult result, ApplicationId applicationId, HttpRequest request, Zone zone) { + super(result.prepareResult().deployLogger().slime()); TenantName tenantName = applicationId.tenant(); - String message = "Session " + result.sessionId() + " for tenant '" + tenantName.value() + "' prepared and activated."; + String message = "Session " + result.prepareResult().sessionId() + " for tenant '" + tenantName.value() + "' prepared" + + (result.activationFailure() == null ? " and activated." 
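The new PrepareAndActivateResult record above lets a deployment report "prepared but not activated" when activation hits ParentHostUnavailableException or ApplicationLockException, instead of failing the whole request; SessionPrepareAndActivateResponse (below) then reports activated=false, and the new PUT binding on .../prepareandactivate/<session-id> can activate the prepared session later. A rough sketch of how a caller consumes the result, based only on the accessors shown in this diff (the repository, stream and params arguments are assumed):

```java
// Sketch only: consuming the PrepareAndActivateResult now returned by ApplicationRepository.deploy.
static void handleDeploy(ApplicationRepository repository,
                         CompressedApplicationInputStream stream,
                         PrepareParams params) {
    PrepareAndActivateResult result = repository.deploy(stream, params);
    long sessionId = result.prepareResult().sessionId();
    if (result.activationFailure() == null) {
        // prepared and activated in one request, as before
        System.out.println("session " + sessionId + " activated");
    } else {
        // prepared only: the transient failure was stored rather than thrown, so the client
        // can activate the prepared session later with
        // PUT /application/v2/tenant/<tenant>/prepareandactivate/<session-id>
        System.out.println("session " + sessionId + " prepared, activation pending");
    }
    // Callers that want the old all-or-nothing behaviour use result.deployResult(),
    // which rethrows the stored activation failure.
}
```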
: ", but activation failed: " + result.activationFailure().getMessage()); Cursor root = slime.get(); - root.setString("session-id", Long.toString(result.sessionId())); root.setString("message", message); - - // TODO: remove unused fields, but add whether activation was successful. + root.setString("session-id", Long.toString(result.prepareResult().sessionId())); + root.setBool("activated", result.activationFailure() == null); root.setString("tenant", tenantName.value()); root.setString("url", "http://" + request.getHost() + ":" + request.getPort() + "/application/v2/tenant/" + tenantName + @@ -35,7 +35,8 @@ public class SessionPrepareAndActivateResponse extends SlimeJsonResponse { "/environment/" + zone.environment().value() + "/region/" + zone.region().value() + "/instance/" + applicationId.instance().value()); - new ConfigChangeActionsSlimeConverter(result.configChangeActions()).toSlime(root); + + new ConfigChangeActionsSlimeConverter(result.prepareResult().configChangeActions()).toSlime(root); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java index dcc5d7caa0d..76879ccf8ae 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.config.server.maintenance; import com.yahoo.config.FileReference; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.TenantName; import com.yahoo.config.subscription.ConfigSourceSet; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Transport; @@ -19,8 +20,10 @@ import com.yahoo.vespa.filedistribution.FileReferenceDownload; import java.io.File; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.Future; import java.util.logging.Logger; import static com.yahoo.vespa.config.server.filedistribution.FileDistributionUtil.fileReferenceExistsOnDisk; @@ -51,39 +54,63 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { @Override protected double maintain() { int attempts = 0; - int failures = 0; - - for (var applicationId : applicationRepository.listApplications()) { - if (shuttingDown()) - break; - - log.finest(() -> "Verifying application package for " + applicationId); - Optional<Session> session = applicationRepository.getActiveSession(applicationId); - if (session.isEmpty()) continue; // App might be deleted after call to listApplications() or not activated yet (bootstrap phase) - - Optional<FileReference> appFileReference = session.get().getApplicationPackageReference(); - if (appFileReference.isPresent()) { - long sessionId = session.get().getSessionId(); - attempts++; - if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { - log.fine(() -> "Downloading application package with file reference " + appFileReference + - " for " + applicationId + " (session " + sessionId + ")"); - - FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), - this.getClass().getSimpleName(), - false); - if (fileDownloader.getFile(download).isEmpty()) { - failures++; - log.info("Downloading application package (" + appFileReference + ")" + - " for " + applicationId + " (session " + sessionId + ") unsuccessful. 
" + - "Can be ignored unless it happens many times over a long period of time, retries is expected"); + int[] failures = new int[1]; + + List<Runnable> futureDownloads = new ArrayList<>(); + for (TenantName tenantName : applicationRepository.tenantRepository().getAllTenantNames()) { + for (Session session : applicationRepository.tenantRepository().getTenant(tenantName).getSessionRepository().getRemoteSessions()) { + if (shuttingDown()) + break; + + switch (session.getStatus()) { + case PREPARE, ACTIVATE: + break; + default: continue; + } + + ApplicationId applicationId = session.getOptionalApplicationId().orElse(null); + if (applicationId == null) // dry-run sessions have no application id + continue; + + log.finest(() -> "Verifying application package for " + applicationId); + + Optional<FileReference> appFileReference = session.getApplicationPackageReference(); + if (appFileReference.isPresent()) { + long sessionId = session.getSessionId(); + attempts++; + if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { + log.fine(() -> "Downloading application package with file reference " + appFileReference + + " for " + applicationId + " (session " + sessionId + ")"); + + FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), + this.getClass().getSimpleName(), + false); + Future<Optional<File>> futureDownload = fileDownloader.getFutureFileOrTimeout(download); + futureDownloads.add(() -> { + try { + if (futureDownload.get().isPresent()) { + createLocalSessionIfMissing(applicationId, sessionId); + return; + } + } + catch (Exception ignored) { } + failures[0]++; + log.info("Downloading application package (" + appFileReference + ")" + + " for " + applicationId + " (session " + sessionId + ") unsuccessful. " + + "Can be ignored unless it happens many times over a long period of time, retries is expected"); + }); + } + else { + createLocalSessionIfMissing(applicationId, sessionId); } } - createLocalSessionIfMissing(applicationId, sessionId); } } - return asSuccessFactorDeviation(attempts, failures); + + futureDownloads.forEach(Runnable::run); + + return asSuccessFactorDeviation(attempts, failures[0]); } private static FileDownloader createFileDownloader(ApplicationRepository applicationRepository, @@ -92,7 +119,7 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { List<String> otherConfigServersInCluster = getOtherConfigServersInCluster(applicationRepository.configserverConfig()); ConfigSourceSet configSourceSet = new ConfigSourceSet(otherConfigServersInCluster); ConnectionPool connectionPool = new FileDistributionConnectionPool(configSourceSet, supervisor); - return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(300)); + return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(60)); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 52c11ed0e93..2f0d8b4065d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -79,6 +79,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import java.util.logging.Level; import 
java.util.logging.Logger; @@ -369,7 +370,7 @@ public class SessionRepository { return session; } - public int deleteExpiredRemoteSessions(Clock clock) { + public int deleteExpiredRemoteSessions(Clock clock, Predicate<Session> sessionIsActiveForApplication) { Duration expiryTime = Duration.ofSeconds(expiryTimeFlag.value()); List<Long> remoteSessionsFromZooKeeper = getRemoteSessionsFromZooKeeper(); log.log(Level.FINE, () -> "Remote sessions for tenant " + tenantName + ": " + remoteSessionsFromZooKeeper); @@ -377,11 +378,11 @@ public class SessionRepository { int deleted = 0; // Avoid deleting too many in one run int deleteMax = (int) Math.min(1000, Math.max(50, remoteSessionsFromZooKeeper.size() * 0.05)); - for (long sessionId : remoteSessionsFromZooKeeper) { + for (Long sessionId : remoteSessionsFromZooKeeper) { Session session = remoteSessionCache.get(sessionId); if (session == null) session = new RemoteSession(tenantName, sessionId, createSessionZooKeeperClient(sessionId)); - if (session.getStatus() == Session.Status.ACTIVATE) continue; + if (session.getStatus() == Session.Status.ACTIVATE && sessionIsActiveForApplication.test(session)) continue; if (sessionHasExpired(session.getCreateTime(), expiryTime, clock)) { log.log(Level.FINE, () -> "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it"); deleteRemoteSessionFromZooKeeper(session); @@ -616,7 +617,7 @@ public class SessionRepository { // ---------------- Common stuff ---------------------------------------------------------------- - public void deleteExpiredSessions(Map<ApplicationId, Long> activeSessions) { + public void deleteExpiredSessions(Predicate<Session> sessionIsActiveForApplication) { log.log(Level.FINE, () -> "Deleting expired local sessions for tenant '" + tenantName + "'"); Set<Long> sessionIdsToDelete = new HashSet<>(); Set<Long> newSessions = findNewSessionsInFileSystem(); @@ -650,8 +651,7 @@ public class SessionRepository { Optional<ApplicationId> applicationId = session.getOptionalApplicationId(); if (applicationId.isEmpty()) continue; - Long activeSession = activeSessions.get(applicationId.get()); - if (activeSession == null || activeSession != sessionId) { + if ( ! 
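[Illustrative fragment, not part of the commit] The SessionRepository hunks above replace the map of active sessions with a Predicate&lt;Session&gt;, so the repository no longer needs to know how activeness is determined. Assuming the caller still holds such a map, the predicate it passes in could be built roughly like this (hypothetical caller-side code, using only methods visible in this diff):

    import java.util.Map;
    import java.util.function.Predicate;

    import com.yahoo.config.provision.ApplicationId;
    import com.yahoo.vespa.config.server.session.Session;

    class ActiveSessionPredicateSketch {

        // A session counts as active when it is exactly the session id recorded as active for its application.
        static Predicate<Session> activeForApplication(Map<ApplicationId, Long> activeSessions) {
            return session -> session.getOptionalApplicationId()
                                     .map(id -> Long.valueOf(session.getSessionId()).equals(activeSessions.get(id)))
                                     .orElse(false);
        }
    }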
sessionIsActiveForApplication.test(session)) { sessionIdsToDelete.add(sessionId); log.log(Level.FINE, () -> "Will delete inactive session " + sessionId + " created " + createTime + " for '" + applicationId + "'"); diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 2e666089152..069b7ffc496 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -111,6 +111,7 @@ </handler> <handler id='com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/prepareandactivate</binding> + <binding>http://*/application/v2/tenant/*/prepareandactivate/*</binding> </handler> <handler id='com.yahoo.vespa.config.server.http.v2.SessionContentHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/session/*/content/*</binding> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/MockProvisioner.java b/configserver/src/test/java/com/yahoo/vespa/config/server/MockProvisioner.java index d93ee19085a..0632da173af 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/MockProvisioner.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/MockProvisioner.java @@ -23,6 +23,7 @@ import java.util.List; public class MockProvisioner implements Provisioner { private boolean transientFailureOnPrepare = false; + private RuntimeException activationFailure = null; private HostProvisioner hostProvisioner = null; public MockProvisioner hostProvisioner(HostProvisioner hostProvisioner) { @@ -35,19 +36,29 @@ public class MockProvisioner implements Provisioner { return this; } + public MockProvisioner activationFailure(RuntimeException activationFailure) { + this.activationFailure = activationFailure; + return this; + } + @Override public List<HostSpec> prepare(ApplicationId applicationId, ClusterSpec cluster, Capacity capacity, ProvisionLogger logger) { - if (hostProvisioner != null) { - return hostProvisioner.prepare(cluster, capacity, logger); - } if (transientFailureOnPrepare) { throw new LoadBalancerServiceException("Unable to create load balancer", new Exception("some internal exception")); } + if (hostProvisioner != null) { + return hostProvisioner.prepare(cluster, capacity, logger); + } throw new UnsupportedOperationException("This mock does not support prepare"); } @Override public void activate(Collection<HostSpec> hosts, ActivationContext context, ApplicationTransaction transaction) { + if (activationFailure != null) { + RuntimeException toThrow = activationFailure; + activationFailure = null; + throw toThrow; + } } @Override diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java index 891284a3a0e..e0a58888109 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java @@ -206,7 +206,7 @@ public class FileServerTest { super(FileDownloader.emptyConnectionPool(), new Supervisor(new Transport("mock")).setDropEmptyBuffers(true), downloadDirectory, - Duration.ofMillis(100), + Duration.ofMillis(1000), Duration.ofMillis(100)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java 
b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java new file mode 100644 index 00000000000..d4aa0676c4f --- /dev/null +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java @@ -0,0 +1,338 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.server.http.v2; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.provision.InMemoryProvisioner; +import com.yahoo.config.provision.ApplicationLockException; +import com.yahoo.config.provision.ParentHostUnavailableException; +import com.yahoo.config.provision.TenantName; +import com.yahoo.config.provision.Zone; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.container.jdisc.ThreadedHttpRequestHandler.Context; +import com.yahoo.jdisc.http.HttpRequest.Method; +import com.yahoo.slime.SlimeUtils; +import com.yahoo.test.ManualClock; +import com.yahoo.vespa.config.server.ApplicationRepository; +import com.yahoo.vespa.config.server.MockProvisioner; +import com.yahoo.vespa.config.server.application.OrchestratorMock; +import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; +import com.yahoo.vespa.config.server.tenant.TenantRepository; +import com.yahoo.vespa.config.server.tenant.TestTenantRepository; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; +import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.time.Clock; +import java.util.Arrays; +import java.util.Map; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +import static com.yahoo.yolean.Exceptions.uncheck; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author jonmv + */ +class ApplicationApiHandlerTest { + + private static final TenantName tenant = TenantName.from("test"); + private static final Map<String, String> appPackage = Map.of("services.xml", + """ + <services version='1.0'> + <container id='jdisc' version='1.0'> + <nodes count='2' /> + </container> + </services> + """, + + "deployment.xml", + """ + <deployment version='1.0' /> + """); + static final String minimalPrepareParams = """ + { + "containerEndpoints": [ + { + "clusterId": "jdisc", + "scope": "zone", + "names": ["zone.endpoint"], + "routingMethod": "exclusive", + "authMethod": "mtls" + } + ] + } + """; + + private final Curator curator = new MockCurator(); + private ApplicationRepository applicationRepository; + + private MockProvisioner provisioner; + private ConfigserverConfig configserverConfig; + private TenantRepository tenantRepository; + private ApplicationApiHandler handler; + + @TempDir + public Path dbDir, defsDir, refsDir; + + @BeforeEach + public void setupRepo() throws IOException { + configserverConfig = new ConfigserverConfig.Builder() + .hostedVespa(true) + .configServerDBDir(dbDir.toString()) + .configDefinitionsDir(defsDir.toString()) + .fileReferencesDir(refsDir.toString()) + .build(); + Clock clock = new ManualClock(); + provisioner = new 
MockProvisioner().hostProvisioner(new InMemoryProvisioner(4, false)); + tenantRepository = new TestTenantRepository.Builder() + .withConfigserverConfig(configserverConfig) + .withCurator(curator) + .withHostProvisionerProvider(HostProvisionerProvider.withProvisioner(provisioner, configserverConfig)) + .build(); + tenantRepository.addTenant(tenant); + applicationRepository = new ApplicationRepository.Builder() + .withTenantRepository(tenantRepository) + .withOrchestrator(new OrchestratorMock()) + .withClock(clock) + .withConfigserverConfig(configserverConfig) + .build(); + handler = new ApplicationApiHandler(new Context(Runnable::run, null), + applicationRepository, + configserverConfig, + Zone.defaultZone()); + } + + private HttpResponse put(long sessionId, Map<String, String> parameters) throws IOException { + var request = com.yahoo.container.jdisc.HttpRequest.createTestRequest("http://host:123/application/v2/tenant/" + tenant + "/prepareandactivate/" + sessionId, + Method.PUT, + InputStream.nullInputStream(), + parameters); + return handler.handle(request); + } + + private HttpResponse post(String json, byte[] appZip, Map<String, String> parameters) throws IOException { + HttpEntity entity = MultipartEntityBuilder.create() + .addTextBody("prepareParams", json, ContentType.APPLICATION_JSON) + .addBinaryBody("applicationPackage", appZip, ContentType.create("application/zip"), "applicationZip") + .build(); + var request = com.yahoo.container.jdisc.HttpRequest.createTestRequest("http://host:123/application/v2/tenant/" + tenant + "/prepareandactivate", + Method.POST, + entity.getContent(), + parameters); + request.getJDiscRequest().headers().add("Content-Type", entity.getContentType()); + return handler.handle(request); + } + + private static byte[] zip(Map<String, String> files) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + try (ZipOutputStream zip = new ZipOutputStream(buffer)) { + files.forEach((name, content) -> uncheck(() -> { + zip.putNextEntry(new ZipEntry(name)); + zip.write(content.getBytes(UTF_8)); + })); + } + return buffer.toByteArray(); + } + + private static void verifyResponse(HttpResponse response, int expectedStatusCode, String expectedBody) throws IOException { + String body = new ByteArrayOutputStream() {{ response.render(this); }}.toString(UTF_8); + assertEquals(expectedStatusCode, response.getStatus(), "Status code should match. 
Response was:\n" + body); + assertEquals(SlimeUtils.toJson(SlimeUtils.jsonToSlimeOrThrow(expectedBody).get(), false), + SlimeUtils.toJson(SlimeUtils.jsonToSlimeOrThrow(body).get(), false)); + } + + @Test + void testMinimalDeployment() throws Exception { + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 2 for tenant 'test' prepared and activated.", + "session-id": "2", + "activated": true, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + } + + @Test + void testBadZipDeployment() throws Exception { + verifyResponse(post("{ }", Arrays.copyOf(zip(appPackage), 13), Map.of()), + 400, + """ + { + "error-code": "BAD_REQUEST", + "message": "Error preprocessing application package for test.default, session 2: services.xml does not exist in application package" + } + """); + } + + @Test + void testPrepareFailure() throws Exception { + provisioner.transientFailureOnPrepare(); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 409, + """ + { + "error-code": "LOAD_BALANCER_NOT_READY", + "message": "Unable to create load balancer: some internal exception" + } + """); + } + + @Test + void testActivateInvalidSession() throws Exception { + verifyResponse(put(2, Map.of()), + 404, + """ + { + "error-code": "NOT_FOUND", + "message": "Local session 2 for 'test' was not found" + } + """); + } + + @Test + void testActivationFailuresAndRetries() throws Exception { + // Prepare session 2, and activate it successfully. + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 2 for tenant 'test' prepared and activated.", + "session-id": "2", + "activated": true, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 3, but fail on hosts; this session will be activated later. + provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 3 for tenant 'test' prepared, but activation failed: host still booting", + "session-id": "3", + "activated": false, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 4, but fail on lock; this session will become outdated later. + provisioner.activationFailure(new ApplicationLockException("lock timeout")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 4 for tenant 'test' prepared, but activation failed: lock timeout", + "session-id": "4", + "activated": false, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 4, but fail with some other exception, which we won't retry. 
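[Illustrative sketch, not part of the commit] The rest of this test retries activation of an already-prepared session with a PUT to the new prepareandactivate/&lt;session-id&gt; binding added to services.xml earlier in this diff. Outside of a test, such a retry could be issued roughly as below; the host, port and session number are hypothetical:

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;

    public class RetryActivationSketch {

        public static void main(String[] args) throws Exception {
            HttpClient client = HttpClient.newHttpClient();
            HttpRequest retry = HttpRequest.newBuilder()           // retry activation of prepared session 3
                    .uri(URI.create("http://configserver:19071/application/v2/tenant/test/prepareandactivate/3"))
                    .PUT(HttpRequest.BodyPublishers.noBody())
                    .build();
            HttpResponse<String> response = client.send(retry, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.statusCode() + ": " + response.body());
        }
    }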
+ provisioner.activationFailure(new RuntimeException("some other exception")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 500, + """ + { + "error-code": "INTERNAL_SERVER_ERROR", + "message": "some other exception" + } + """); + + // Retry only activation of session 3, but fail again with hosts. + provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + verifyResponse(put(3, Map.of()), + 409, + """ + { + "error-code": "PARENT_HOST_NOT_READY", + "message": "host still booting" + } + """); + + // Retry only activation of session 3, but fail again with lock. + provisioner.activationFailure(new ApplicationLockException("lock timeout")); + verifyResponse(put(3, Map.of()), + 500, + """ + { + "error-code": "APPLICATION_LOCK_FAILURE", + "message": "lock timeout" + } + """); + + // Retry only activation of session 3, and succeed! + provisioner.activationFailure(null); + verifyResponse(put(3, Map.of()), + 200, + """ + { + "message": "Session 3 for test.default.default activated" + } + """); + + // Retry only activation of session 4, but fail because it is now based on an outdated session. + verifyResponse(put(4, Map.of()), + 409, + """ + { + "error-code": "ACTIVATION_CONFLICT", + "message": "app:test.default.default Cannot activate session 4 because the currently active session (3) has changed since session 4 was created (was 2 at creation time)" + } + """); + + // Retry activation of session 3 again, and fail. + verifyResponse(put(3, Map.of()), + 400, + """ + { + "error-code": "BAD_REQUEST", + "message": "app:test.default.default Session 3 is already active" + } + """); + } + +} diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index aedccbee46b..2304743873f 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -1154,6 +1154,7 @@ "public void addItem(int, com.yahoo.prelude.query.Item)", "public com.yahoo.prelude.query.Item setItem(int, com.yahoo.prelude.query.Item)", "public java.util.Optional extractSingleChild()", + "public void setWeight(int)", "public com.yahoo.prelude.query.WordItem getWordItem(int)", "public com.yahoo.prelude.query.BlockItem getBlockItem(int)", "protected void encodeThis(java.nio.ByteBuffer)", diff --git a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java index 84b2b482403..4f31db0fc86 100644 --- a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java +++ b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java @@ -20,7 +20,7 @@ public class MapEncoder { // TODO: Time to refactor - private static byte [] getUtf8(Object value) { + private static byte[] getUtf8(Object value) { if (value == null) { return Utf8.toBytes(""); } else if (value instanceof Tensor) { @@ -62,7 +62,7 @@ public class MapEncoder { public static int encodeMap(String mapName, Map<String,?> map, ByteBuffer buffer) { if (map.isEmpty()) return 0; - byte [] utf8 = Utf8.toBytes(mapName); + byte[] utf8 = Utf8.toBytes(mapName); buffer.putInt(utf8.length); buffer.put(utf8); buffer.putInt(map.size()); diff --git a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java index 441c4326355..88cc7ad7b2d 100644 --- a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java @@ -172,6 +172,21 @@ public class 
ClusterSearcher extends Searcher { } @Override + public Result search(Query query, Execution execution) { + validateQueryTimeout(query); + validateQueryCache(query); + Searcher searcher = server; + if (searcher == null) { + return new Result(query, ErrorMessage.createNoBackendsInService("Could not search")); + } + if (query.getTimeLeft() <= 0) { + return new Result(query, ErrorMessage.createTimeout("No time left for searching")); + } + + return doSearch(searcher, query, execution); + } + + @Override public void fill(com.yahoo.search.Result result, String summaryClass, Execution execution) { Query query = result.getQuery(); @@ -192,21 +207,6 @@ public class ClusterSearcher extends Searcher { } } - @Override - public Result search(Query query, Execution execution) { - validateQueryTimeout(query); - validateQueryCache(query); - Searcher searcher = server; - if (searcher == null) { - return new Result(query, ErrorMessage.createNoBackendsInService("Could not search")); - } - if (query.getTimeLeft() <= 0) { - return new Result(query, ErrorMessage.createTimeout("No time left for searching")); - } - - return doSearch(searcher, query, execution); - } - private void validateQueryTimeout(Query query) { if (query.getTimeout() <= maxQueryTimeout) return; diff --git a/container-search/src/main/java/com/yahoo/prelude/query/PhraseItem.java b/container-search/src/main/java/com/yahoo/prelude/query/PhraseItem.java index 755e8de5e4f..ff811c97ba4 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/PhraseItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/PhraseItem.java @@ -153,14 +153,29 @@ public class PhraseItem extends CompositeIndexedItem { private void addIndexedItem(int index, IndexedItem word) { word.setIndexName(this.getIndexName()); + if (word instanceof Item item) { + item.setWeight(this.getWeight()); + } super.addItem(index, (Item) word); } private Item setIndexedItem(int index, IndexedItem word) { word.setIndexName(this.getIndexName()); + if (word instanceof Item item) { + item.setWeight(this.getWeight()); + } return super.setItem(index, (Item) word); } + @Override + public void setWeight(int weight) { + super.setWeight(weight); + for (Iterator<Item> i = getItemIterator(); i.hasNext();) { + Item word = i.next(); + word.setWeight(weight); + } + } + /** * Returns a subitem as a word item * diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java index 0010291de66..43bd175b348 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java @@ -11,6 +11,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.QueryCanonicalizer; import com.yahoo.prelude.query.RankItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WeakAndItem; import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.parser.ParserEnvironment; @@ -79,8 +80,8 @@ public class AllParser extends SimpleParser { // Combine the items Item topLevel = and; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } @@ -130,6 +131,7 @@ public class AllParser extends SimpleParser { if ( ! 
tokens.skip(MINUS)) return null; if (tokens.currentIsNoIgnore(SPACE)) return null; var itemAndExplicitIndex = indexableItem(); + item = itemAndExplicitIndex.getFirst(); boolean explicitIndex = itemAndExplicitIndex.getSecond(); if (item == null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java index efc804fcf1f..1bbc21768b5 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java @@ -12,6 +12,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.TermItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.search.query.parser.ParserEnvironment; import java.util.Iterator; @@ -106,9 +107,8 @@ public class AnyParser extends SimpleParser { } return rank; } else if ((topLevelItem instanceof RankItem) - && (item instanceof RankItem) + && (item instanceof RankItem itemAsRank) && (((RankItem) item).getItem(0) instanceof OrItem)) { - RankItem itemAsRank = (RankItem) item; OrItem or = (OrItem) itemAsRank.getItem(0); ((RankItem) topLevelItem).addItem(0, or); @@ -139,8 +139,10 @@ public class AnyParser extends SimpleParser { if (root instanceof PhraseItem) { root.setFilter(true); } - for (Iterator<Item> i = ((CompositeItem) root).getItemIterator(); i.hasNext();) { - markAllTermsAsFilters(i.next()); + if (root instanceof CompositeItem composite) { + for (Iterator<Item> i = composite.getItemIterator(); i.hasNext(); ) { + markAllTermsAsFilters(i.next()); + } } } } @@ -206,8 +208,7 @@ public class AnyParser extends SimpleParser { return root; } - if (root instanceof RankItem) { - RankItem rootAsRank = (RankItem) root; + if (root instanceof RankItem rootAsRank) { Item firstChild = rootAsRank.getItem(0); if (firstChild instanceof NotItem) { @@ -228,7 +229,6 @@ public class AnyParser extends SimpleParser { } NotItem not = new NotItem(); - not.addPositiveItem(root); not.addNegativeItem(item); return not; diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java index deab2be9d00..ea0cd2312a6 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java @@ -130,14 +130,13 @@ abstract class SimpleParser extends StructuredParser { } } if (not != null && not.getPositiveItem() instanceof TrueItem) { - // Incomplete not, only negatives - - + // Incomplete not, only negatives - simplify when possible if (topLevelItem != null && topLevelItem != not) { // => neutral rank items becomes implicit positives not.addPositiveItem(getItemAsPositiveItem(topLevelItem, not)); return not; - } else { // Only negatives - ignore them - return null; + } else { + return not; } } if (topLevelItem != null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java index 06ea583c53f..75396a8714f 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java @@ -6,6 +6,7 @@ import com.yahoo.prelude.query.CompositeItem; import 
com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NotItem; import com.yahoo.prelude.query.OrItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.search.query.parser.ParserEnvironment; @@ -69,8 +70,8 @@ public class WebParser extends AllParser { if (or != null) topLevel = or; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java index 663704b1e0b..736dc7d7f39 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java @@ -68,8 +68,8 @@ public class ValidateSortingSearcher extends Searcher { @Override public Result search(Query query, Execution execution) { + ErrorMessage e = validate(query); if (indexingMode != QrSearchersConfig.Searchcluster.Indexingmode.STREAMING) { - ErrorMessage e = validate(query); if (e != null) { Result r = new Result(query); r.hits().addError(e); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index bef766e7ef9..70f6e405a92 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,6 +15,7 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; + private final Map<String, String> contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -30,6 +31,7 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; + this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -44,6 +46,9 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } + /** Returns a read-only map of context key-values which can be looked up during conversion. 
*/ + Map<String,String> contextValues() { return contextValues; } + /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cfadd79de8f..e16f8e7b0cd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,7 +48,8 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, + context.language(), context.contextValues()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index c9f935e5f52..25a5c277dce 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,7 +44,8 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage()); + query.getModel().getLanguage(), + context); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java index 4ac5375807b..fd0b6543f28 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java @@ -38,16 +38,12 @@ public class RankProperties implements Cloneable { /** Adds a property by full name to a value */ public void put(String name, Object value) { - List<Object> list = properties.get(name); - if (list == null) { - list = new ArrayList<>(); - properties.put(name, list); - } + List<Object> list = properties.computeIfAbsent(name, k -> new ArrayList<>()); list.add(value); } /** - * Returns a read-only list of properties properties by full name. + * Returns a read-only list of properties by full name. * If this is not set, null is returned. If this is explicitly set to * have no values, and empty list is returned. 
*/ diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 6da53ae699c..94f92c7fd48 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,7 +19,8 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); private final Map<String, Embedder> embedders; @@ -27,8 +28,9 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language) { - var context = new Embedder.Context(key).setLanguage(language); + public Tensor convertTo(TensorType type, String key, Object value, Language language, + Map<String, String> contextValues) { + var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -55,16 +57,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { + Matcher matcher; + if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); argument = matcher.group(2); - if ( ! embedders.containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - } - embedder = embedders.get(embedderId); - } else if (embedders.size() == 0) { + } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { + embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); + argument = matcher.group(2); + } else if (embedders.isEmpty()) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -74,19 +76,35 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); } - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { + private Embedder requireEmbedder(String embedderId) { + if ( ! embedders.containsKey(embedderId)) + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. 
" + + "Valid embedders are " + validEmbedders(embedders)); + return embedders.get(embedderId); + } + + private static String resolve(String s, Embedder.Context embedderContext) { + if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { + if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); - } + if (s.startsWith("@")) + return resolveReference(s, embedderContext); return s; } + private static String resolveReference(String s, Embedder.Context embedderContext) { + String referenceKey = s.substring(1); + String referencedValue = embedderContext.getContextValues().get(referenceKey); + if (referencedValue == null) + throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + + "' used in an embed() argument"); + return referencedValue; + } + private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); diff --git a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java index 4383be184fa..156c34e5005 100644 --- a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.SubstringItem; import com.yahoo.prelude.query.SuffixItem; import com.yahoo.prelude.query.TaggableItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.language.process.SpecialTokens; import com.yahoo.prelude.query.parser.TestLinguistics; @@ -262,17 +263,28 @@ public class ParseTestCase { @Test void testNotOnly() { - tester.assertParsed(null, "-foobar", Query.Type.ALL); + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test - void testMultipleNotsOnlt() { - tester.assertParsed(null, "-foo -bar -foobar", Query.Type.ALL); + void testNotOnlyAny() { + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); + } + + @Test + void testMultipleNotsOnly() { + Item item = tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test void testOnlyNotComposite() { - tester.assertParsed(null, "-(foo bar baz)", Query.Type.ALL); + tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.ALL); } @Test @@ -391,7 +403,7 @@ public class ParseTestCase { @Test void testMinusAndPluses() { - tester.assertParsed(null, "--test+-if", Query.Type.ANY); + tester.assertParsed("-(AND test if)", "--test+-if", Query.Type.ANY); } @Test @@ -1305,7 +1317,9 @@ public class ParseTestCase { @Test void testNotFilterEmptyQuery() { - tester.assertParsed(null, "", "-foo", Query.Type.ANY); + Item item = tester.assertParsed("-|foo", "", "-foo", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test @@ -1380,7 +1394,7 @@ public class ParseTestCase { @Test void 
testMultitermNotFilterEmptyQuery() { - tester.assertParsed(null, "", "-foo -foz", Query.Type.ANY); + tester.assertParsed("-|foo -|foz", "", "-foo -foz", Query.Type.ANY); } @Test @@ -2320,17 +2334,19 @@ public class ParseTestCase { @Test void testNotOnlyWeb() { - tester.assertParsed(null, "-foobar", Query.Type.WEB); + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.WEB); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem)item).getPositiveItem() instanceof TrueItem); } @Test void testMultipleNotsOnltWeb() { - tester.assertParsed(null, "-foo -bar -foobar", Query.Type.WEB); + tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.WEB); } @Test void testOnlyNotCompositeWeb() { - tester.assertParsed(null, "-(foo bar baz)", Query.Type.WEB); + tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.WEB); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 90e21e5f3b0..429b8d1c6cb 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -185,6 +185,21 @@ public class RankProfileInputTest { assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } + @Test + void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameter() { + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(@param1)", embedding1, embedders, null, text); + assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, text); + assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders, + "Could not resolve query parameter reference 'noSuchParam' " + + "used in an embed() argument"); + } + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { return new Query.Builder() .setSchemaInfo(createSchemaInfo()) @@ -202,18 +217,24 @@ public class RankProfileInputTest { } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { - assertEmbedQuery(embed, expected, embedders, null); + assertEmbedQuery(embed, expected, embedders, null, null); } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + assertEmbedQuery(embed, expected, embedders, language, null); + } + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) { String languageParam = language == null ? "" : "&language=" + language; + String param1 = param1Value == null ? "" : "¶m1=" + urlEncode(param1Value); + String destination = "query(myTensor4)"; Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( "?" + urlEncode("ranking.features." 
+ destination) + "=" + urlEncode(embed) + "&ranking=commonProfile" + - languageParam, + languageParam + + param1, com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) @@ -230,7 +251,7 @@ public class RankProfileInputTest { if (t.getMessage().equals(errMsg)) return; t = t.getCause(); } - fail("Error '" + errMsg + "' not thrown"); + fail("Exception with message '" + errMsg + "' not thrown"); } private CompiledQueryProfile createQueryProfile() { diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index d7dcdd2c769..ddd29b0971c 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -58,15 +58,15 @@ <antlr4.vespa.version>4.13.1</antlr4.vespa.version> <apache.httpclient.vespa.version>4.5.14</apache.httpclient.vespa.version> <apache.httpcore.vespa.version>4.4.16</apache.httpcore.vespa.version> - <apache.httpclient5.vespa.version>5.3</apache.httpclient5.vespa.version> + <apache.httpclient5.vespa.version>5.3.1</apache.httpclient5.vespa.version> <apache.httpcore5.vespa.version>5.2.4</apache.httpcore5.vespa.version> <apiguardian.vespa.version>1.1.2</apiguardian.vespa.version> <asm.vespa.version>9.6</asm.vespa.version> - <assertj.vespa.version>3.25.1</assertj.vespa.version> + <assertj.vespa.version>3.25.2</assertj.vespa.version> <!-- Athenz dependencies. Make sure these dependencies match those in Vespa's internal repositories --> - <athenz.vespa.version>1.11.49</athenz.vespa.version> - <aws-sdk.vespa.version>1.12.638</aws-sdk.vespa.version> + <aws-sdk.vespa.version>1.12.649</aws-sdk.vespa.version> + <athenz.vespa.version>1.11.51</athenz.vespa.version> <!-- Athenz END --> <!-- WARNING: If you change curator version, you also need to update @@ -90,7 +90,7 @@ <commons-compress.vespa.version>1.25.0</commons-compress.vespa.version> <commons-cli.vespa.version>1.6.0</commons-cli.vespa.version> <curator.vespa.version>5.6.0</curator.vespa.version> - <dropwizard.metrics.vespa.version>4.2.23</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> + <dropwizard.metrics.vespa.version>4.2.25</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> <eclipse-collections.vespa.version>11.1.0</eclipse-collections.vespa.version> <eclipse-sisu.vespa.version>0.9.0.M2</eclipse-sisu.vespa.version> <failureaccess.vespa.version>1.0.2</failureaccess.vespa.version> @@ -105,7 +105,7 @@ <java-jwt.vespa.version>4.4.0</java-jwt.vespa.version> <javax.annotation.vespa.version>1.2</javax.annotation.vespa.version> <jaxb.runtime.vespa.version>4.0.4</jaxb.runtime.vespa.version> - <jetty.vespa.version>11.0.19</jetty.vespa.version> + <jetty.vespa.version>11.0.20</jetty.vespa.version> <jetty-servlet-api.vespa.version>5.0.2</jetty-servlet-api.vespa.version> <jimfs.vespa.version>1.3.0</jimfs.vespa.version> <jna.vespa.version>5.14.0</jna.vespa.version> @@ -114,13 +114,13 @@ <junit.platform.vespa.version>1.10.1</junit.platform.vespa.version> <junit4.vespa.version>4.13.2</junit4.vespa.version> <luben.zstd.vespa.version>1.5.5-11</luben.zstd.vespa.version> - <lucene.vespa.version>9.9.1</lucene.vespa.version> + <lucene.vespa.version>9.9.2</lucene.vespa.version> <maven-archiver.vespa.version>3.6.1</maven-archiver.vespa.version> <maven-wagon.vespa.version>3.5.3</maven-wagon.vespa.version> <mimepull.vespa.version>1.10.0</mimepull.vespa.version> - <mockito.vespa.version>5.9.0</mockito.vespa.version> + <mockito.vespa.version>5.10.0</mockito.vespa.version> <mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version> - 
<netty.vespa.version>4.1.105.Final</netty.vespa.version> + <netty.vespa.version>4.1.106.Final</netty.vespa.version> <netty-tcnative.vespa.version>2.0.62.Final</netty-tcnative.vespa.version> <onnxruntime.vespa.version>1.16.3</onnxruntime.vespa.version> <opennlp.vespa.version>2.3.1</opennlp.vespa.version> @@ -130,7 +130,7 @@ <prometheus.client.vespa.version>0.16.0</prometheus.client.vespa.version> <plexus-interpolation.vespa.version>1.27</plexus-interpolation.vespa.version> <protobuf.vespa.version>3.25.2</protobuf.vespa.version> - <questdb.vespa.version>7.3.7</questdb.vespa.version> + <questdb.vespa.version>7.3.9</questdb.vespa.version> <spifly.vespa.version>1.3.7</spifly.vespa.version> <spotbugs.vespa.version>4.8.3</spotbugs.vespa.version> <!-- Must match major version in https://github.com/apache/zookeeper/blob/master/pom.xml --> <snappy.vespa.version>1.1.10.5</snappy.vespa.version> @@ -152,7 +152,7 @@ <surefire.vespa.tenant.version>${surefire.vespa.version}</surefire.vespa.tenant.version> <!-- Maven plugins --> - <clover-maven-plugin.vespa.version>4.5.1</clover-maven-plugin.vespa.version> + <clover-maven-plugin.vespa.version>4.5.2</clover-maven-plugin.vespa.version> <maven-antrun-plugin.vespa.version>3.1.0</maven-antrun-plugin.vespa.version> <maven-assembly-plugin.vespa.version>3.6.0</maven-assembly-plugin.vespa.version> <maven-bundle-plugin.vespa.version>5.1.9</maven-bundle-plugin.vespa.version> diff --git a/document/src/main/java/com/yahoo/document/json/JsonReader.java b/document/src/main/java/com/yahoo/document/json/JsonReader.java index 3e1743b8d45..b6cf8c6e18b 100644 --- a/document/src/main/java/com/yahoo/document/json/JsonReader.java +++ b/document/src/main/java/com/yahoo/document/json/JsonReader.java @@ -6,8 +6,10 @@ import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.DocumentId; import com.yahoo.document.DocumentOperation; +import com.yahoo.document.DocumentPut; import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; +import com.yahoo.document.DocumentUpdate; import com.yahoo.document.TestAndSetCondition; import com.yahoo.document.json.document.DocumentParser; import com.yahoo.document.json.readers.DocumentParseInfo; @@ -18,6 +20,9 @@ import java.io.InputStream; import java.util.Optional; import static com.yahoo.document.json.JsonReader.ReaderState.END_OF_FEED; +import static com.yahoo.document.json.document.DocumentParser.CONDITION; +import static com.yahoo.document.json.document.DocumentParser.CREATE_IF_NON_EXISTENT; +import static com.yahoo.document.json.document.DocumentParser.FIELDS; import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart; /** @@ -60,7 +65,7 @@ public class JsonReader { * @param docIdString document ID * @return the parsed document operation */ - public ParsedDocumentOperation readSingleDocument(DocumentOperationType operationType, String docIdString) { + ParsedDocumentOperation readSingleDocument(DocumentOperationType operationType, String docIdString) { DocumentId docId = new DocumentId(docIdString); DocumentParseInfo documentParseInfo; try { @@ -78,6 +83,79 @@ public class JsonReader { return operation; } + /** + * Reads a JSON which is expected to contain a single document operation, + * and where other parameters, like the document ID and operation type, are supplied by other means. 
+ * + * @param operationType the type of operation (update or put) + * @param docIdString document ID + * @return the parsed document operation + */ + public ParsedDocumentOperation readSingleDocumentStreaming(DocumentOperationType operationType, String docIdString) { + try { + DocumentId docId = new DocumentId(docIdString); + DocumentParseInfo documentParseInfo = new DocumentParseInfo(); + documentParseInfo.documentId = docId; + documentParseInfo.operationType = operationType; + + if (JsonToken.START_OBJECT != parser.nextValue()) + throw new IllegalArgumentException("expected start of root object, got " + parser.currentToken()); + + Boolean create = null; + String condition = null; + ParsedDocumentOperation operation = null; + while (JsonToken.END_OBJECT != parser.nextValue()) { + switch (parser.getCurrentName()) { + case FIELDS -> { + documentParseInfo.fieldsBuffer = new LazyTokenBuffer(parser); + VespaJsonDocumentReader vespaJsonDocumentReader = new VespaJsonDocumentReader(typeManager.getIgnoreUndefinedFields()); + operation = vespaJsonDocumentReader.createDocumentOperation( + getDocumentTypeFromString(documentParseInfo.documentId.getDocType(), typeManager), documentParseInfo); + + if ( ! documentParseInfo.fieldsBuffer.isEmpty()) + throw new IllegalArgumentException("expected all content to be consumed by document parsing, but " + + documentParseInfo.fieldsBuffer.nesting() + " levels remain"); + + } + case CONDITION -> { + if ( ! JsonToken.VALUE_STRING.equals(parser.currentToken()) && ! JsonToken.VALUE_NULL.equals(parser.currentToken())) + throw new IllegalArgumentException("expected string value for condition, got " + parser.currentToken()); + + condition = parser.getValueAsString(); + } + case CREATE_IF_NON_EXISTENT -> { + create = parser.getBooleanValue(); // Throws if not boolean. + } + default -> { + // We ignore stray fields, but need to ensure structural balance in doing do. 
+ if (parser.currentToken().isStructStart()) parser.skipChildren(); + } + } + } + + if (null != parser.nextToken()) + throw new IllegalArgumentException("expected end of input, got " + parser.currentToken()); + + if (null == operation) + throw new IllegalArgumentException("document is missing the required \"fields\" field"); + + if (null != create) { + switch (operationType) { + case PUT -> ((DocumentPut) operation.operation()).setCreateIfNonExistent(create); + case UPDATE -> ((DocumentUpdate) operation.operation()).setCreateIfNonExistent(create); + case REMOVE -> throw new IllegalArgumentException(CREATE_IF_NON_EXISTENT + " is not supported for remove operations"); + } + } + + operation.operation().setCondition(TestAndSetCondition.fromConditionString(Optional.ofNullable(condition))); + + return operation; + } + catch (IOException e) { + throw new IllegalArgumentException("failed parsing document", e); + } + } + /** Returns the next document operation, or null if we have reached the end */ public DocumentOperation next() { switch (state) { diff --git a/document/src/main/java/com/yahoo/document/json/LazyTokenBuffer.java b/document/src/main/java/com/yahoo/document/json/LazyTokenBuffer.java new file mode 100644 index 00000000000..0fbdd0b28c7 --- /dev/null +++ b/document/src/main/java/com/yahoo/document/json/LazyTokenBuffer.java @@ -0,0 +1,64 @@ +package com.yahoo.document.json; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * A {@link TokenBuffer} which only buffers tokens when needed, i.e., when peeking. + * + * @author jonmv + */ +public class LazyTokenBuffer extends TokenBuffer { + + private final JsonParser parser; + + public LazyTokenBuffer(JsonParser parser) { + this.parser = parser; + try { addFromParser(parser); } + catch (IOException e) { throw new IllegalArgumentException("failed parsing document JSON", e); } + if (JsonToken.START_OBJECT != current()) + throw new IllegalArgumentException("expected start of JSON object, but got " + current()); + updateNesting(current()); + } + + void advance() { + super.advance(); + if (tokens.isEmpty() && nesting() > 0) tokens.add(nextToken()); // Fill current token if needed and possible. + } + + @Override + public Supplier<Token> lookahead() { + return new Supplier<>() { + int localNesting = nesting(); + Supplier<Token> buffered = LazyTokenBuffer.super.lookahead(); + @Override public Token get() { + if (localNesting == 0) + return null; + + Token token = buffered.get(); + if (token == null) { + token = nextToken(); + tokens.add(token); + } + localNesting += nestingOffset(token.token); + return token; + } + }; + } + + private Token nextToken() { + try { + JsonToken token = parser.nextValue(); + if (token == null) + throw new IllegalStateException("no more JSON tokens"); + return new Token(token, parser.getCurrentName(), parser.getText()); + } + catch (IOException e) { + throw new IllegalArgumentException("failed reading document JSON", e); + } + } + +} diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index a9cd3cc87a8..3a48f71c4cd 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -1,15 +1,16 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.document.json; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import com.google.common.base.Preconditions; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.function.Supplier; + /** * Helper class to enable lookahead in the token stream. * @@ -17,101 +18,76 @@ import com.google.common.base.Preconditions; */ public class TokenBuffer { - private final List<Token> tokens; + final Deque<Token> tokens = new ArrayDeque<>(); - private int position = 0; private int nesting = 0; - public TokenBuffer() { - this(new ArrayList<>()); - } - - public TokenBuffer(List<Token> tokens) { - this.tokens = tokens; - if (tokens.size() > 0) - updateNesting(tokens.get(position).token); - } + public TokenBuffer() { } /** Returns whether any tokens are available in this */ - public boolean isEmpty() { return remaining() == 0; } - - public JsonToken previous() { - updateNestingGoingBackwards(current()); - position--; - return current(); - } - - /** Returns the current token without changing position, or null if none */ - public JsonToken current() { - if (isEmpty()) return null; - Token token = tokens.get(position); - if (token == null) return null; - return token.token; - } + public boolean isEmpty() { return tokens.isEmpty(); } + /** Returns the next token, or null, and updates the nesting count of this. */ public JsonToken next() { - position++; + advance(); JsonToken token = current(); updateNesting(token); return token; } - /** Returns a given number of tokens ahead, or null if none */ - public JsonToken peek(int ahead) { - if (tokens.size() <= position + ahead) return null; - return tokens.get(position + ahead).token; + void advance() { + tokens.poll(); + } + + /** Returns the current token without changing position, or null if none */ + public JsonToken current() { + return isEmpty() ? null : tokens.peek().token; } /** Returns the current token name without changing position, or null if none */ public String currentName() { - if (isEmpty()) return null; - Token token = tokens.get(position); - if (token == null) return null; - return token.name; + return isEmpty() ? null : tokens.peek().name; } /** Returns the current token text without changing position, or null if none */ public String currentText() { - if (isEmpty()) return null; - Token token = tokens.get(position); - if (token == null) return null; - return token.text; + return isEmpty() ? null : tokens.peek().text; } - public int remaining() { - return tokens.size() - position; + /** + * Returns a sequence of remaining tokens in this, or nulls when none remain. + * This may fill the token buffer, but not otherwise modify it. + */ + public Supplier<Token> lookahead() { + Iterator<Token> iterator = tokens.iterator(); + if (iterator.hasNext()) iterator.next(); + return () -> iterator.hasNext() ? 
iterator.next() : null; } private void add(JsonToken token, String name, String text) { - tokens.add(tokens.size(), new Token(token, name, text)); + tokens.add(new Token(token, name, text)); } - public void bufferObject(JsonToken first, JsonParser tokens) { - bufferJsonStruct(first, tokens, JsonToken.START_OBJECT); + public void bufferObject(JsonParser parser) { + bufferJsonStruct(parser, JsonToken.START_OBJECT); } - private void bufferJsonStruct(JsonToken first, JsonParser tokens, JsonToken firstToken) { - int localNesting = 0; - JsonToken t = first; + private void bufferJsonStruct(JsonParser parser, JsonToken firstToken) { + JsonToken token = parser.currentToken(); + Preconditions.checkArgument(token == firstToken, + "Expected %s, got %s.", firstToken.name(), token); + updateNesting(token); - Preconditions.checkArgument(first == firstToken, - "Expected %s, got %s.", firstToken.name(), t); - if (remaining() == 0) { - updateNesting(t); + try { + for (int nesting = addFromParser(parser); nesting > 0; nesting += addFromParser(parser)) + parser.nextValue(); } - localNesting = storeAndPeekNesting(t, localNesting, tokens); - while (localNesting > 0) { - t = nextValue(tokens); - localNesting = storeAndPeekNesting(t, localNesting, tokens); + catch (IOException e) { + throw new IllegalArgumentException(e); } } - private int storeAndPeekNesting(JsonToken t, int nesting, JsonParser tokens) { - addFromParser(t, tokens); - return nesting + nestingOffset(t); - } - - private int nestingOffset(JsonToken token) { + int nestingOffset(JsonToken token) { if (token == null) return 0; if (token.isStructStart()) { return 1; @@ -122,71 +98,23 @@ public class TokenBuffer { } } - private void addFromParser(JsonToken t, JsonParser tokens) { - try { - add(t, tokens.getCurrentName(), tokens.getText()); - } catch (IOException e) { - throw new IllegalArgumentException(e); - } - } - - private JsonToken nextValue(JsonParser tokens) { - try { - return tokens.nextValue(); - } catch (IOException e) { - throw new IllegalArgumentException(e); - } + int addFromParser(JsonParser tokens) throws IOException { + add(tokens.currentToken(), tokens.getCurrentName(), tokens.getText()); + return nestingOffset(tokens.currentToken()); } - private void updateNesting(JsonToken token) { + void updateNesting(JsonToken token) { nesting += nestingOffset(token); } - private void updateNestingGoingBackwards(JsonToken token) { - nesting -= nestingOffset(token); - } - public int nesting() { return nesting; } - public Token prefetchScalar(String name) { - int localNesting = nesting(); - int nestingBarrier = localNesting; - Token toReturn = null; - Iterator<Token> i; - - if (name.equals(currentName()) && current().isScalarValue()) { - toReturn = tokens.get(position); - } else { - i = tokens.iterator(); - i.next(); // just ignore the first value, as we know it's not what - // we're looking for, and it's nesting effect is already - // included - while (i.hasNext()) { - Token t = i.next(); - if (localNesting == nestingBarrier && name.equals(t.name) && t.token.isScalarValue()) { - toReturn = t; - break; - } - localNesting += nestingOffset(t.token); - if (localNesting < nestingBarrier) { - break; - } - } - } - return toReturn; - } - public void skipToRelativeNesting(int relativeNesting) { int initialNesting = nesting(); - do { - next(); - } while ( nesting() > initialNesting + relativeNesting); - } - - public List<Token> rest() { - return tokens.subList(position, tokens.size()); + do next(); + while (nesting() > initialNesting + relativeNesting); } 
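[Illustrative sketch, not part of the commit] The TokenBuffer rework above replaces index-based peeking with a lookahead() supplier, and the new LazyTokenBuffer only pulls further tokens from the underlying JsonParser when that supplier runs ahead of what has already been consumed. A small usage sketch, assuming jackson-core and the document module on the classpath (everything else follows the methods shown in this diff):

    import com.fasterxml.jackson.core.JsonFactory;
    import com.fasterxml.jackson.core.JsonParser;
    import com.yahoo.document.json.LazyTokenBuffer;
    import com.yahoo.document.json.TokenBuffer;

    import java.util.function.Supplier;

    public class LookaheadSketch {

        public static void main(String[] args) throws Exception {
            JsonParser parser = new JsonFactory().createParser("{\"fields\":{\"title\":\"hello\"}}");
            parser.nextToken();                                        // position the parser at START_OBJECT
            TokenBuffer buffer = new LazyTokenBuffer(parser);

            Supplier<TokenBuffer.Token> ahead = buffer.lookahead();
            System.out.println("peeked ahead: " + ahead.get());       // start of the "fields" object, buffered lazily from the parser
            System.out.println("still current: " + buffer.current()); // START_OBJECT: peeking consumed nothing
        }
    }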
public static final class Token { diff --git a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java index 74656762fe1..77e11dcf2a8 100644 --- a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java +++ b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java @@ -20,7 +20,7 @@ public class DocumentParser { private static final String UPDATE = "update"; private static final String PUT = "put"; private static final String ID = "id"; - private static final String CONDITION = "condition"; + public static final String CONDITION = "condition"; public static final String CREATE_IF_NON_EXISTENT = "create"; public static final String FIELDS = "fields"; public static final String REMOVE = "remove"; @@ -86,16 +86,6 @@ public class DocumentParser { private void handleIdentLevelOne(DocumentParseInfo documentParseInfo, boolean docIdAndOperationIsSetExternally) throws IOException { JsonToken currentToken = parser.getCurrentToken(); - if (currentToken == JsonToken.VALUE_TRUE || currentToken == JsonToken.VALUE_FALSE) { - try { - if (CREATE_IF_NON_EXISTENT.equals(parser.getCurrentName())) { - documentParseInfo.create = Optional.ofNullable(parser.getBooleanValue()); - return; - } - } catch (IOException e) { - throw new RuntimeException("Got IO exception while parsing document", e); - } - } if ((currentToken == JsonToken.VALUE_TRUE || currentToken == JsonToken.VALUE_FALSE) && CREATE_IF_NON_EXISTENT.equals(parser.getCurrentName())) { documentParseInfo.create = Optional.of(currentToken == JsonToken.VALUE_TRUE); @@ -111,12 +101,11 @@ public class DocumentParser { } } - private void handleIdentLevelTwo(DocumentParseInfo documentParseInfo) { + private void handleIdentLevelTwo(DocumentParseInfo documentParseInfo) { try { - JsonToken currentToken = parser.getCurrentToken(); // "fields" opens a dictionary and is therefore on level two which might be surprising. 
- if (currentToken == JsonToken.START_OBJECT && FIELDS.equals(parser.getCurrentName())) { - documentParseInfo.fieldsBuffer.bufferObject(currentToken, parser); + if (parser.currentToken() == JsonToken.START_OBJECT && FIELDS.equals(parser.getCurrentName())) { + documentParseInfo.fieldsBuffer.bufferObject(parser); processIndent(); } } catch (IOException e) { diff --git a/document/src/main/java/com/yahoo/document/json/readers/DocumentParseInfo.java b/document/src/main/java/com/yahoo/document/json/readers/DocumentParseInfo.java index 2dce07cdbe6..e859306f04d 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/DocumentParseInfo.java +++ b/document/src/main/java/com/yahoo/document/json/readers/DocumentParseInfo.java @@ -8,6 +8,7 @@ import com.yahoo.document.json.TokenBuffer; import java.util.Optional; public class DocumentParseInfo { + public DocumentParseInfo() { } public DocumentId documentId; public Optional<Boolean> create = Optional.empty(); public Optional<String> condition = Optional.empty(); diff --git a/document/src/main/java/com/yahoo/document/json/readers/MapReader.java b/document/src/main/java/com/yahoo/document/json/readers/MapReader.java index 6c850fe4320..b45a0001fd1 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/MapReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/MapReader.java @@ -90,45 +90,39 @@ public class MapReader { @SuppressWarnings({ "rawtypes", "unchecked" }) public static ValueUpdate createMapUpdate(TokenBuffer buffer, DataType currentLevel, - FieldValue keyParent, - FieldValue topLevelKey, boolean ignoreUndefinedFields) { - TokenBuffer.Token element = buffer.prefetchScalar(UPDATE_ELEMENT); + if ( ! JsonToken.START_OBJECT.equals(buffer.current())) + throw new IllegalArgumentException("Expected object for match update, got " + buffer.current()); + buffer.next(); + + FieldValue key = null; + ValueUpdate update; + if (UPDATE_ELEMENT.equals(buffer.currentName())) { + key = keyTypeForMapUpdate(buffer.currentText(), currentLevel); buffer.next(); } - FieldValue key = keyTypeForMapUpdate(element, currentLevel); - if (keyParent != null) { - ((CollectionFieldValue) keyParent).add(key); - } - // structure is: [(match + element)*, (element + action)] - // match will always have element, and either match or action - if (!UPDATE_MATCH.equals(buffer.currentName())) { - // we have reached an action... - if (topLevelKey == null) { - return ValueUpdate.createMap(key, readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields)); - } else { - return ValueUpdate.createMap(topLevelKey, readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields)); - } - } else { - // next level of matching - if (topLevelKey == null) { - return createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), key, key, ignoreUndefinedFields); - } else { - return createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), key, topLevelKey, ignoreUndefinedFields); - } + update = UPDATE_MATCH.equals(buffer.currentName()) ? createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), ignoreUndefinedFields) + : readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields); + buffer.next(); + + if (key == null) { + if ( ! 
UPDATE_ELEMENT.equals(buffer.currentName())) + throw new IllegalArgumentException("Expected match element, got " + buffer.current()); + key = keyTypeForMapUpdate(buffer.currentText(), currentLevel); + buffer.next(); } + + if ( ! JsonToken.END_OBJECT.equals(buffer.current())) + throw new IllegalArgumentException("Expected object end for match update, got " + buffer.current()); + + return ValueUpdate.createMap(key, update); } @SuppressWarnings("rawtypes") public static ValueUpdate createMapUpdate(TokenBuffer buffer, Field field, boolean ignoreUndefinedFields) { - buffer.next(); - MapValueUpdate m = (MapValueUpdate) MapReader.createMapUpdate(buffer, field.getDataType(), null, null, ignoreUndefinedFields); - buffer.next(); - // must generate the field value in parallell with the actual - return m; - + return MapReader.createMapUpdate(buffer, field.getDataType(), ignoreUndefinedFields); } private static DataType valueTypeForMapUpdate(DataType parentType) { @@ -143,14 +137,14 @@ public class MapReader { } } - private static FieldValue keyTypeForMapUpdate(TokenBuffer.Token element, DataType expectedType) { + private static FieldValue keyTypeForMapUpdate(String elementText, DataType expectedType) { FieldValue v; if (expectedType instanceof ArrayDataType) { - v = new IntegerFieldValue(Integer.valueOf(element.text)); + v = new IntegerFieldValue(Integer.valueOf(elementText)); } else if (expectedType instanceof WeightedSetDataType) { - v = ((WeightedSetDataType) expectedType).getNestedType().createFieldValue(element.text); + v = ((WeightedSetDataType) expectedType).getNestedType().createFieldValue(elementText); } else if (expectedType instanceof MapDataType) { - v = ((MapDataType) expectedType).getKeyType().createFieldValue(element.text); + v = ((MapDataType) expectedType).getKeyType().createFieldValue(elementText); } else { throw new IllegalArgumentException("Container type " + expectedType + " not supported for match update."); } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 0b7b1ae9996..1fd4029b1a5 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -4,13 +4,15 @@ package com.yahoo.document.json.readers; import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; -import com.yahoo.slime.Inspector; -import com.yahoo.slime.Type; +import com.yahoo.document.json.TokenBuffer.Token; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; + +import java.util.function.Supplier; import static com.yahoo.document.json.readers.JsonParserHelpers.*; import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; @@ -37,36 +39,43 @@ public class TensorReader { Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); expectOneOf(buffer.current(), JsonToken.START_OBJECT, JsonToken.START_ARRAY); int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { - if (TENSOR_CELLS.equals(buffer.currentName()) && ! 
primitiveContent(buffer)) { + while (true) { + Supplier<Token> lookahead = buffer.lookahead(); + Token next = lookahead.get(); + if (TENSOR_CELLS.equals(next.name) && ! primitiveContent(next.token, lookahead.get().token)) { + buffer.next(); readTensorCells(buffer, builder); } - else if (TENSOR_VALUES.equals(buffer.currentName()) && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) { + else if (TENSOR_VALUES.equals(next.name) && builder.type().dimensions().stream().allMatch(Dimension::isIndexed)) { + buffer.next(); readTensorValues(buffer, builder); } - else if (TENSOR_BLOCKS.equals(buffer.currentName())) { + else if (TENSOR_BLOCKS.equals(next.name)) { + buffer.next(); readTensorBlocks(buffer, builder); } - else if (TENSOR_TYPE.equals(buffer.currentName()) && buffer.current() == JsonToken.VALUE_STRING) { + else if (TENSOR_TYPE.equals(next.name) && next.token == JsonToken.VALUE_STRING) { + buffer.next(); // Ignore input tensor type } + else if (buffer.nesting() == initNesting && JsonToken.END_OBJECT == next.token) { + buffer.next(); + break; + } else { - buffer.previous(); // Back up to the start of the enclosing block readDirectTensorValue(buffer, builder); - buffer.previous(); // ... and back up to the end of the enclosing block + break; } } expectOneOf(buffer.current(), JsonToken.END_OBJECT, JsonToken.END_ARRAY); tensorFieldValue.assign(builder.build()); } - static boolean primitiveContent(TokenBuffer buffer) { - JsonToken cellsValue = buffer.current(); - if (cellsValue.isScalarValue()) return true; - if (cellsValue == JsonToken.START_ARRAY) { - JsonToken firstArrayValue = buffer.peek(1); - if (firstArrayValue == JsonToken.END_ARRAY) return false; - if (firstArrayValue.isScalarValue()) return true; + static boolean primitiveContent(JsonToken current, JsonToken next) { + if (current.isScalarValue()) return true; + if (current == JsonToken.START_ARRAY) { + if (next == JsonToken.END_ARRAY) return false; + if (next.isScalarValue()) return true; } return false; } @@ -186,7 +195,7 @@ public class TensorReader { boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - if (isArrayOfObjects(buffer, 0)) + if (isArrayOfObjects(buffer)) readTensorCells(buffer, builder); else if ( ! 
hasMapped) readTensorValues(buffer, builder); @@ -196,10 +205,12 @@ public class TensorReader { readTensorCells(buffer, builder); } - private static boolean isArrayOfObjects(TokenBuffer buffer, int ahead) { - if (buffer.peek(ahead++) != JsonToken.START_ARRAY) return false; - if (buffer.peek(ahead) == JsonToken.START_ARRAY) return isArrayOfObjects(buffer, ahead); // nested array - return buffer.peek(ahead) == JsonToken.START_OBJECT; + private static boolean isArrayOfObjects(TokenBuffer buffer) { + if (buffer.current() != JsonToken.START_ARRAY) return false; + Supplier<Token> lookahead = buffer.lookahead(); + Token next; + while ((next = lookahead.get()).token == JsonToken.START_ARRAY) { } + return next.token == JsonToken.START_OBJECT; } private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { diff --git a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java index 113b8732b23..067dabdbdab 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java @@ -230,7 +230,7 @@ public class VespaJsonDocumentReader { private static boolean isFieldPath(String field) { - return field.matches("^.*?[.\\[\\{].*$"); + return field.matches("^.*?[.\\[{].*$"); } private static void verifyEndState(TokenBuffer buffer, JsonToken expectedFinalToken) { @@ -238,7 +238,7 @@ public class VespaJsonDocumentReader { "Expected end of JSON struct (%s), got %s", expectedFinalToken, buffer.current()); Preconditions.checkState(buffer.nesting() == 0, "Nesting not zero at end of operation"); Preconditions.checkState(buffer.next() == null, "Dangling data at end of operation"); - Preconditions.checkState(buffer.remaining() == 0, "Dangling data at end of operation"); + Preconditions.checkState(buffer.isEmpty(), "Dangling data at end of operation"); } } diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index ef2b40c962d..7e15a729684 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -119,7 +119,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { for (int i = 0; i < type.dimensions().size(); ++i) { var dim = type.dimensions().get(i); if (dim.isMapped()) { - builder.add(dim.name(), address.label(i)); + builder.add(dim.name(), (int) address.numericLabel(i)); } } return builder.build(); diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 8a45fe95fa2..aa043a25d78 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -20,6 +20,7 @@ import com.yahoo.document.MapDataType; import com.yahoo.document.PositionDataType; import com.yahoo.document.StructDataType; import com.yahoo.document.TensorDataType; +import com.yahoo.document.TestAndSetCondition; import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.BoolFieldValue; @@ -31,6 +32,7 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.Struct; import 
com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.datatypes.WeightedSet; +import com.yahoo.document.fieldpathupdate.FieldPathUpdate; import com.yahoo.document.internal.GeoPosType; import com.yahoo.document.json.readers.DocumentParseInfo; import com.yahoo.document.json.readers.VespaJsonDocumentReader; @@ -62,6 +64,7 @@ import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -82,6 +85,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -151,6 +155,13 @@ public class JsonReaderTestCase { types.registerDocumentType(x); } { + DocumentType x = new DocumentType("testArrayOfArrayOfInt"); + DataType inner = new ArrayDataType(DataType.INT); + DataType outer = new ArrayDataType(inner); + x.addField(new Field("arrayOfArrayOfInt", outer)); + types.registerDocumentType(x); + } + { DocumentType x = new DocumentType("testsinglepos"); DataType d = PositionDataType.INSTANCE; x.addField(new Field("singlepos", d)); @@ -211,103 +222,169 @@ public class JsonReaderTestCase { } @Test - public void readSingleDocumentPut() { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse'", - " }", - "}")); - DocumentPut put = (DocumentPut) r.readSingleDocument(DocumentOperationType.PUT, - "id:unittest:smoke::doc1").operation(); + public void readDocumentWithMissingFieldsField() { + assertEquals("document is missing the required \"fields\" field", + assertThrows(IllegalArgumentException.class, + () -> createReader("{ }").readSingleDocumentStreaming(DocumentOperationType.PUT, + "id:unittest:testnull::whee")) + .getMessage()); + } + + @Test + public void readSingleDocumentsPutStreaming() throws IOException { + String json = """ + { + "remove": "id:unittest:smoke::ignored", + "ignored-extra-array": [{ "foo": null }, { }], + "ignored-extra-object": { "foo": [null, { }], "bar": { } }, + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + }, + "id": "id:unittest:smoke::ignored", + "create": false, + "condition": "true" + } + """; + ParsedDocumentOperation operation = createReader(json).readSingleDocumentStreaming(DocumentOperationType.PUT,"id:unittest:smoke::doc1"); + DocumentPut put = ((DocumentPut) operation.operation()); + assertFalse(put.getCreateIfNonExistent()); + assertEquals("true", put.getCondition().getSelection()); smokeTestDoc(put.getDocument()); } @Test - public final void readSingleDocumentUpdate() { - JsonReader r = createReader(inputJson("{ 'update': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': {", - " 'assign': 'orOther' }}}")); - DocumentUpdate doc = (DocumentUpdate) r.readSingleDocument(DocumentOperationType.UPDATE, "id:unittest:smoke::whee").operation(); + public void readSingleDocumentsUpdateStreaming() throws IOException { + String json = """ + { + "remove": "id:unittest:smoke::ignored", + "ignored-extra-array": [{ "foo": null }, { }], + "ignored-extra-object": { "foo": [null, { }], "bar": { } }, + "fields": { + "something": { "assign": "smoketest" }, + "flag": { "assign": true }, + "nalle": { "assign": 
"bamse" } + }, + "id": "id:unittest:smoke::ignored", + "create": true, + "condition": "false" + } + """; + ParsedDocumentOperation operation = createReader(json).readSingleDocumentStreaming(DocumentOperationType.UPDATE,"id:unittest:smoke::doc1"); + Document doc = new Document(types.getDocumentType("smoke"), new DocumentId("id:unittest:smoke::doc1")); + DocumentUpdate update = ((DocumentUpdate) operation.operation()); + update.applyTo(doc); + smokeTestDoc(doc); + assertTrue(update.getCreateIfNonExistent()); + assertEquals("false", update.getCondition().getSelection()); + } + + @Test + public void readSingleDocumentPut() throws IOException { + Document doc = docFromJson(""" + { + "put": "id:unittest:smoke::doc1", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + } + """); + smokeTestDoc(doc); + } + + @Test + public final void readSingleDocumentUpdate() throws IOException { + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:smoke::whee", + "fields": { + "something": { + "assign": "orOther" + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("something"); assertEquals(1, f.size()); assertTrue(f.getValueUpdate(0) instanceof AssignValueUpdate); + assertEquals(new StringFieldValue("orOther"), f.getValueUpdate(0).getValue()); } @Test - public void readClearField() { - JsonReader r = createReader(inputJson("{ 'update': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'int1': {", - " 'assign': null }}}")); - DocumentUpdate doc = (DocumentUpdate) r.readSingleDocument(DocumentOperationType.UPDATE, "id:unittest:smoke::whee").operation(); + public void readClearField() throws IOException { + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:smoke::whee", + "fields": { + "int1": { + "assign": null + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("int1"); assertEquals(1, f.size()); assertTrue(f.getValueUpdate(0) instanceof ClearValueUpdate); assertNull(f.getValueUpdate(0).getValue()); } - @Test public void smokeTest() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse'", - " }", - "}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - smokeTestDoc(put.getDocument()); + Document doc = docFromJson(""" + { + "put": "id:unittest:smoke::doc1", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + } + """); + smokeTestDoc(doc); } @Test public void docIdLookaheadTest() throws IOException { - JsonReader r = createReader(inputJson( - "{ 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse'", - " },", - " 'put': 'id:unittest:smoke::doc1'", - " }", - "}")); - - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - smokeTestDoc(put.getDocument()); + Document doc = docFromJson(""" + { + "put": "id:unittest:smoke::doc1", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + } + """); + smokeTestDoc(doc); } - @Test public void emptyDocTest() throws IOException { - 
JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::whee', 'fields': {}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - assertEquals("id:unittest:smoke::whee", parseInfo.documentId.toString()); + Document doc = docFromJson(""" + { + "put": "id:unittest:smoke::whee", + "fields": { } + }"""); + assertEquals(new Document(types.getDocumentType("smoke"), new DocumentId("id:unittest:smoke::whee")), + doc); } @Test public void testStruct() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:mirrors::whee',", - " 'fields': {", - " 'skuggsjaa': {", - " 'sandra': 'person',", - " 'cloud': 'another person' }}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:mirrors::whee", + "fields": { + "skuggsjaa": { + "sandra": "person", + "cloud": "another person" + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("skuggsjaa")); assertSame(Struct.class, f.getClass()); Struct s = (Struct) f; @@ -326,13 +403,20 @@ public class JsonReaderTestCase { @Test public void testStructUpdate() throws IOException { - DocumentUpdate put = parseUpdate(inputJson("{ 'update': 'id:unittest:mirrors:g=test:whee',", - " 'create': true,", - " 'fields': {", - " 'skuggsjaa': {", - " 'assign': {", - " 'sandra': 'person',", - " 'cloud': 'another person' }}}}")); + DocumentUpdate put = parseUpdate(""" + { + "update": "id:unittest:mirrors:g=test:whee", + "create": true, + "fields": { + "skuggsjaa": { + "assign": { + "sandra": "person", + "cloud": "another person" + } + } + } + } + """); assertEquals(1, put.fieldUpdates().size()); FieldUpdate fu = put.fieldUpdates().iterator().next(); assertEquals(1, fu.getValueUpdates().size()); @@ -351,11 +435,17 @@ public class JsonReaderTestCase { @Test public final void testEmptyStructUpdate() throws IOException { - DocumentUpdate put = parseUpdate(inputJson("{ 'update': 'id:unittest:mirrors:g=test:whee',", - " 'create': true,", - " 'fields': { ", - " 'skuggsjaa': {", - " 'assign': { } }}}")); + DocumentUpdate put = parseUpdate(""" + { + "update": "id:unittest:mirrors:g=test:whee", + "create": true, + "fields": { + "skuggsjaa": { + "assign": { } + } + } + } + """); assertEquals(1, put.fieldUpdates().size()); FieldUpdate fu = put.fieldUpdates().iterator().next(); assertEquals(1, fu.getValueUpdates().size()); @@ -373,23 +463,37 @@ public class JsonReaderTestCase { @Test public void testUpdateArray() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + } + """); checkSimpleArrayAdd(doc); } @Test public void testUpdateWeighted() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", 
- " 'actualset': {", - " 'add': {", - " 'person': 37,", - " 'another person': 41 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "add": { + "person": 37, + "another person": 41 + } + } + } + } + """); Map<String, Integer> weights = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -409,12 +513,34 @@ public class JsonReaderTestCase { @Test public void testUpdateMatch() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'match': {", - " 'element': 'person',", - " 'increment': 13 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "element": "person", + "increment": 13 + } + } + } + } + """); + + DocumentUpdate otherDoc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "increment": 13, + "element": "person" + } + } + } + }"""); + + assertEquals(doc, otherDoc); Map<String, Tuple2<Number, String>> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -437,21 +563,28 @@ public class JsonReaderTestCase { @Test public void testArithmeticOperators() throws IOException { Tuple2[] operations = new Tuple2[] { - new Tuple2<String, Operator>(UPDATE_DECREMENT, - ArithmeticValueUpdate.Operator.SUB), - new Tuple2<String, Operator>(UPDATE_DIVIDE, + new Tuple2<>(UPDATE_DECREMENT, + ArithmeticValueUpdate.Operator.SUB), + new Tuple2<>(UPDATE_DIVIDE, ArithmeticValueUpdate.Operator.DIV), - new Tuple2<String, Operator>(UPDATE_INCREMENT, + new Tuple2<>(UPDATE_INCREMENT, ArithmeticValueUpdate.Operator.ADD), - new Tuple2<String, Operator>(UPDATE_MULTIPLY, + new Tuple2<>(UPDATE_MULTIPLY, ArithmeticValueUpdate.Operator.MUL) }; for (Tuple2<String, Operator> operator : operations) { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'match': {", - " 'element': 'person',", - " '" + (String) operator.first + "': 13 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "element": "person", + "%s": 13 + } + } + } + } + """.formatted(operator.first)); Map<String, Tuple2<Number, Operator>> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -475,12 +608,19 @@ public class JsonReaderTestCase { @SuppressWarnings("rawtypes") @Test public void testArrayIndexing() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'match': {", - " 'element': 3,", - " 'assign': 'nalle' }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "match": { + "element": 3, + "assign": "nalle" + } + } + } + } + """); Map<Number, String> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualarray"); @@ -488,7 +628,7 @@ public class JsonReaderTestCase { MapValueUpdate adder = (MapValueUpdate) v; final Number key = ((IntegerFieldValue) adder.getValue()) .getNumber(); - String op = ((StringFieldValue) ((AssignValueUpdate) adder.getUpdate()) + String op = ((StringFieldValue) adder.getUpdate() .getValue()).getString(); matches.put(key, op); } @@ -515,11 +655,17 @@ public class JsonReaderTestCase { @Test public void 
testWeightedSet() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'nalle': 2,", - " 'tralle': 7 }}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testset::whee", + "fields": { + "actualset": { + "nalle": 2, + "tralle": 7 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualset")); assertSame(WeightedSet.class, f.getClass()); WeightedSet<?> w = (WeightedSet<?>) f; @@ -530,11 +676,17 @@ public class JsonReaderTestCase { @Test public void testArray() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': [", - " 'nalle',", - " 'tralle' ]}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testarray::whee", + "fields": { + "actualarray": [ + "nalle", + "tralle" + ] + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualarray")); assertSame(Array.class, f.getClass()); Array<?> a = (Array<?>) f; @@ -545,11 +697,17 @@ public class JsonReaderTestCase { @Test public void testMap() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testmap::whee',", - " 'fields': {", - " 'actualmap': {", - " 'nalle': 'kalle',", - " 'tralle': 'skalle' }}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testmap::whee", + "fields": { + "actualmap": { + "nalle": "kalle", + "tralle": "skalle" + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualmap")); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -560,11 +718,23 @@ public class JsonReaderTestCase { @Test public void testOldMap() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testmap::whee',", - " 'fields': {", - " 'actualmap': [", - " { 'key': 'nalle', 'value': 'kalle'},", - " { 'key': 'tralle', 'value': 'skalle'} ]}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testmap::whee", + "fields": { + "actualmap": [ + { + "key": "nalle", + "value": "kalle" + }, + { + "key": "tralle", + "value": "skalle" + } + ] + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualmap")); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -575,9 +745,14 @@ public class JsonReaderTestCase { @Test public void testPositionPositive() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': 'N63.429722;E10.393333' }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": "N63.429722;E10.393333" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -586,9 +761,17 @@ public class JsonReaderTestCase { @Test public void testPositionOld() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'x':10393333,'y':63429722} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "x": 10393333, + "y": 63429722 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -597,9 +780,17 
@@ public class JsonReaderTestCase { @Test public void testGeoPosition() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'lat':63.429722,'lng':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "lat": 63.429722, + "lng": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -608,9 +799,17 @@ public class JsonReaderTestCase { @Test public void testGeoPositionNoAbbreviations() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'latitude':63.429722,'longitude':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "latitude": 63.429722, + "longitude": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -619,9 +818,14 @@ public class JsonReaderTestCase { @Test public void testPositionGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': 'N63.429722;E10.393333' }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": "N63.429722;E10.393333" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -631,9 +835,17 @@ public class JsonReaderTestCase { @Test public void testPositionOldGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': {'x':10393333,'y':63429722} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": { + "x": 10393333, + "y": 63429722 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -643,9 +855,17 @@ public class JsonReaderTestCase { @Test public void testGeoPositionGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': {'lat':63.429722,'lng':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": { + "lat": 63.429722, + "lng": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -656,9 +876,14 @@ public class JsonReaderTestCase { @Test public void testPositionNegative() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': 'W46.63;S23.55' }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": "W46.63;S23.55" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); 
assertEquals(-46630000, PositionDataType.getXValue(f).getInteger()); @@ -682,14 +907,14 @@ public class JsonReaderTestCase { } private String fieldStringFromBase64RawContent(String base64data) throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testraw::whee',", - " 'fields': {", - " 'actualraw': '" + base64data + "' }}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testraw::whee", + "fields": { + "actualraw": "%s" + } + } + """.formatted(base64data)); FieldValue f = doc.getFieldValue(doc.getField("actualraw")); assertSame(Raw.class, f.getClass()); Raw s = (Raw) f; @@ -698,15 +923,16 @@ public class JsonReaderTestCase { @Test public void testMapStringToArrayOfInt() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'bamse': [1, 2, 3] }}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "bamse": [1, 2, 3] + } + } + } + """); FieldValue f = doc.getFieldValue("actualMapStringToArrayOfInt"); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -719,15 +945,19 @@ public class JsonReaderTestCase { @Test public void testOldMapStringToArrayOfInt() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': [", - " { 'key': 'bamse', 'value': [1, 2, 3] } ]}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": [ + { + "key": "bamse", + "value": [1, 2, 3] + } + ] + } + } + """); FieldValue f = doc.getFieldValue("actualMapStringToArrayOfInt"); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -740,10 +970,16 @@ public class JsonReaderTestCase { @Test public void testAssignToString() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': {", - " 'assign': 'orOther' }}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:smoke::whee", + "fields": { + "something": { + "assign": "orOther" + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("something"); assertEquals(1, f.size()); AssignValueUpdate a = (AssignValueUpdate) f.getValueUpdate(0); @@ -751,11 +987,189 @@ public class JsonReaderTestCase { } 
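
Editor's note: the tests added below exercise the rewritten MapReader.createMapUpdate, which now walks the match object directly off the token buffer and accepts the "element" key and the nested action in either order. As a rough sketch (not part of the commit), the nested match update parsed in testNestedArrayMatch corresponds to building the value update programmatically with classes the test already uses; createAssign is assumed to be the usual ValueUpdate factory.

    // Hypothetical equivalent of: match { element: 1, match: { element: 2, assign: 3 } }
    static ValueUpdate<?> nestedMatchEquivalent() {
        return ValueUpdate.createMap(new IntegerFieldValue(1),                 // outer array index
                   ValueUpdate.createMap(new IntegerFieldValue(2),             // inner array index
                       ValueUpdate.createAssign(new IntegerFieldValue(3))));   // assigned value
    }
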
@Test + public void testNestedArrayMatch() throws IOException { + DocumentUpdate nested = parseUpdate(""" + { + "update": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": { + "match": { + "element": 1, + "match": { + "element": 2, + "assign": 3 + } + } + } + } + } + """); + + DocumentUpdate equivalent = parseUpdate(""" + { + "update": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": { + "match": { + "match": { + "assign": 3, + "element": 2 + }, + "element": 1 + } + } + } + } + """); + + assertEquals(nested, equivalent); + assertEquals(1, nested.fieldUpdates().size()); + FieldUpdate fu = nested.fieldUpdates().iterator().next(); + assertEquals(1, fu.getValueUpdates().size()); + MapValueUpdate mvu = (MapValueUpdate) fu.getValueUpdate(0); + assertEquals(new IntegerFieldValue(1), mvu.getValue()); + MapValueUpdate nvu = (MapValueUpdate) mvu.getUpdate(); + assertEquals(new IntegerFieldValue(2), nvu.getValue()); + AssignValueUpdate avu = (AssignValueUpdate) nvu.getUpdate(); + assertEquals(new IntegerFieldValue(3), avu.getValue()); + + Document doc = docFromJson(""" + { + "put": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": [ + [1, 2, 3], + [4, 5, 6] + ] + } + } + """); + nested.applyTo(doc); + Document expected = docFromJson(""" + { + "put": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": [ + [1, 2, 3], + [4, 5, 3] + ] + } + } + """); + assertEquals(expected, doc); + } + + @Test + public void testMatchCannotUpdateNestedFields() { + // Should this work? It doesn't. + assertEquals("Field type Map<string,Array<int>> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "match": { + "element": "bamse", + "match": { + "element": 1, + "assign": 4 + } + } + } + } + } + """)).getMessage()); + } + + @Test + public void testMatchCannotAssignToNestedMap() { + // Unsupported value type for map value assign. + assertEquals("Field type Map<string,Array<int>> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "match": { + "element": "bamse", + "assign": [1, 3, 4] + } + } + } + } + """)).getMessage()); + } + + @Test + public void testMatchCannotAssignToMap() { + // Unsupported value type for map value assign. 
+ assertEquals("Field type Map<string,string> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testmap::whee", + "fields": { + "actualmap": { + "match": { + "element": "bamse", + "assign": "bar" + } + } + } + } + """)).getMessage()); + } + + + + @Test + public void testAssignInsideArrayInMap() throws IOException { + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "bamse": [1, 2, 3] + } + } + }"""); + + assertEquals(2, ((MapFieldValue<StringFieldValue, Array<IntegerFieldValue>>) doc.getFieldValue("actualMapStringToArrayOfInt")) + .get(StringFieldValue.getFactory().create("bamse")).get(1).getInteger()); + + DocumentUpdate update = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt{bamse}[1]": { + "assign": 4 + } + } + } + """); + assertEquals(1, update.fieldPathUpdates().size()); + + update.applyTo(doc); + assertEquals(4, ((MapFieldValue<StringFieldValue, Array<IntegerFieldValue>>) doc.getFieldValue("actualMapStringToArrayOfInt")) + .get(StringFieldValue.getFactory().create("bamse")).get(1).getInteger()); + } + + @Test public void testAssignToArray() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'assign': { 'bamse': [1, 2, 3] }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "assign": { + "bamse": [1, 2, 3] + } + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("actualMapStringToArrayOfInt"); assertEquals(1, f.size()); AssignValueUpdate assign = (AssignValueUpdate) f.getValueUpdate(0); @@ -769,11 +1183,21 @@ public class JsonReaderTestCase { @Test public void testOldAssignToArray() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'assign': [", - " { 'key': 'bamse', 'value': [1, 2, 3] } ]}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "assign": [ + { + "key": "bamse", + "value": [1, 2, 3] + } + ] + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("actualMapStringToArrayOfInt"); assertEquals(1, f.size()); AssignValueUpdate assign = (AssignValueUpdate) f.getValueUpdate(0); @@ -787,12 +1211,19 @@ public class JsonReaderTestCase { @Test public void testAssignToWeightedSet() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'assign': {", - " 'person': 37,", - " 'another person': 41 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "assign": { + "person": 37, + "another person": 41 + } + } + } + } + """); FieldUpdate x = doc.getFieldUpdate("actualset"); assertEquals(1, x.size()); AssignValueUpdate assign = (AssignValueUpdate) x.getValueUpdate(0); @@ -805,41 +1236,66 @@ public class JsonReaderTestCase { @Test public void testCompleteFeed() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 
'flag': true,", - " 'nalle': 'bamse' }},", - "{ 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}}},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + }, + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); controlBasicFeed(r); } @Test public void testCompleteFeedWithCreateAndCondition() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' }},", - "{", - " 'condition':'bla',", - " 'update': 'id:unittest:testarray::whee',", - " 'create':true,", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}}},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + }, + { + "condition":"bla", + "update": "id:unittest:testarray::whee", + "create":true, + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); DocumentOperation d = r.next(); Document doc = ((DocumentPut) d).getDocument(); @@ -860,7 +1316,7 @@ public class JsonReaderTestCase { @Test public void testUpdateWithConditionAndCreateInDifferentOrdering() { - int documentsCreated = 106; + int documentsCreated = 106; List<String> parts = Arrays.asList( "\"condition\":\"bla\"", "\"update\": \"id:unittest:testarray::whee\"", @@ -876,8 +1332,7 @@ public class JsonReaderTestCase { } } documents.append("]"); - InputStream rawDoc = new ByteArrayInputStream( - Utf8.toBytes(documents.toString())); + InputStream rawDoc = new ByteArrayInputStream(Utf8.toBytes(documents.toString())); JsonReader r = new JsonReader(types, rawDoc, parserFactory); @@ -886,7 +1341,6 @@ public class JsonReaderTestCase { checkSimpleArrayAdd(update); assertTrue(update.getCreateIfNonExistent()); assertEquals("bla", update.getCondition().getSelection()); - } assertNull(r.next()); @@ -895,13 +1349,18 @@ public class JsonReaderTestCase { @Test public void testCreateIfNonExistentInPut() { - JsonReader r = createReader(inputJson("[{", - " 'create':true,", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse' },", - " 'put': 'id:unittest:smoke::whee'", - "}]")); + JsonReader r = createReader(""" + [ + { + "create":true, + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "put": "id:unittest:smoke::whee" + } + ] + """); var op = r.next(); var put = (DocumentPut) op; assertTrue(put.getCreateIfNonExistent()); @@ -909,23 +1368,32 @@ public class JsonReaderTestCase { @Test public void testCompleteFeedWithIdAfterFields() { - JsonReader r = createReader(inputJson("[", - "{", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' },", - " 'put': 'id:unittest:smoke::whee'", - "},", - "{", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}},", - " 'update': 'id:unittest:testarray::whee'", - "},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + 
"fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + }, + "put": "id:unittest:smoke::whee" + }, + { + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + }, + "update": "id:unittest:testarray::whee" + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); controlBasicFeed(r); } @@ -949,10 +1417,21 @@ public class JsonReaderTestCase { @Test public void testCompleteFeedWithEmptyDoc() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee', 'fields': {} },", - "{ 'update': 'id:unittest:testarray::whee', 'fields': {} },", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": {} + }, + { + "update": "id:unittest:testarray::whee", + "fields": {} + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); DocumentOperation d = r.next(); Document doc = ((DocumentPut) d).getDocument(); @@ -994,45 +1473,53 @@ public class JsonReaderTestCase { @Test public void nonExistingFieldCausesException() throws IOException { - JsonReader r = createReader(inputJson( - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'smething': 'smoketest',", - " 'nalle': 'bamse' }}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - - try { - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - fail(); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().startsWith("No field 'smething' in the structure of type 'smoke'")); - } + Exception expected = assertThrows(IllegalArgumentException.class, + () -> docFromJson(""" + { + "put": "id:unittest:smoke::whee", + "fields": { + "smething": "smoketest", + "nalle": "bamse" + } + } + """)); + assertTrue(expected.getMessage().startsWith("No field 'smething' in the structure of type 'smoke'")); } @Test public void nonExistingFieldsCanBeIgnoredInPut() throws IOException { - JsonReader r = createReader(inputJson( - "{ ", - " 'put': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'nonexisting1': 'ignored value',", - " 'field1': 'value1',", - " 'nonexisting2': {", - " 'blocks':{", - " 'a':[2.0,3.0],", - " 'b':[4.0,5.0]", - " }", - " },", - " 'field2': 'value2',", - " 'nonexisting3': {", - " 'cells': [{'address': {'x': 'x1'}, 'value': 1.0}]", - " },", - " 'tensor1': {'cells': {'x1': 1.0}},", - " 'nonexisting4': 'ignored value'", - " }", - "}")); + JsonReader r = createReader(""" + { + "put": "id:unittest:smoke::doc1", + "fields": { + "nonexisting1": "ignored value", + "field1": "value1", + "nonexisting2": { + "blocks": { + "a": [2.0, 3.0], + "b": [4.0, 5.0] + } + }, + "field2": "value2", + "nonexisting3": { + "cells": [ + { + "address": { + "x": "x1" + }, + "value": 1.0 + } + ] + }, + "tensor1": { + "cells": { + "x1": 1.0 + } + }, + "nonexisting4": "ignored value" + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); @@ -1049,30 +1536,31 @@ public class JsonReaderTestCase { @Test public void nonExistingFieldsCanBeIgnoredInUpdate() throws IOException{ - JsonReader r = createReader(inputJson( - "{ ", - " 'update': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'nonexisting1': { 'assign': 'ignored value' },", - " 'field1': { 'assign': 'value1' },", - 
" 'nonexisting2': { " + - " 'assign': {", - " 'blocks': {", - " 'a':[2.0,3.0],", - " 'b':[4.0,5.0]", - " }", - " }", - " },", - " 'field2': { 'assign': 'value2' },", - " 'nonexisting3': {", - " 'assign' : {", - " 'cells': [{'address': {'x': 'x1'}, 'value': 1.0}]", - " }", - " },", - " 'tensor1': {'assign': { 'cells': {'x1': 1.0} } },", - " 'nonexisting4': { 'assign': 'ignored value' }", - " }", - "}")); + JsonReader r = createReader(""" + { + "update": "id:unittest:smoke::doc1", + "fields": { + "nonexisting1": { "assign": "ignored value" }, + "field1": { "assign": "value1" }, + "nonexisting2": { + "assign": { + "blocks": { + "a":[2.0,3.0], + "b":[4.0,5.0] + } + } + }, + "field2": { "assign": "value2" }, + "nonexisting3": { + "assign" : { + "cells": [{"address": {"x": "x1"}, "value": 1.0}] + } + }, + "tensor1": {"assign": { "cells": {"x1": 1.0} } }, + "nonexisting4": { "assign": "ignored value" } + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentUpdate update = new DocumentUpdate(docType, parseInfo.documentId); @@ -1089,26 +1577,44 @@ public class JsonReaderTestCase { @Test public void feedWithBasicErrorTest() { - JsonReader r = createReader(inputJson("[", - " { 'put': 'id:test:smoke::0', 'fields': { 'something': 'foo' } },", - " { 'put': 'id:test:smoke::1', 'fields': { 'something': 'foo' } },", - " { 'put': 'id:test:smoke::2', 'fields': { 'something': 'foo' } },", - "]")); - try { - while (r.next() != null) ; - fail(); - } catch (RuntimeException e) { - assertTrue(e.getMessage().contains("JsonParseException")); - } + JsonReader r = createReader(""" + [ + { + "put": "id:test:smoke::0", + "fields": { + "something": "foo" + } + }, + { + "put": "id:test:smoke::1", + "fields": { + "something": "foo" + } + }, + { + "put": "id:test:smoke::2", + "fields": { + "something": "foo" + } + }, + ]"""); // Trailing comma in array ... 
+ assertTrue(assertThrows(RuntimeException.class, + () -> { while (r.next() != null); }) + .getMessage().contains("JsonParseException")); } @Test public void idAsAliasForPutTest() throws IOException{ - JsonReader r = createReader(inputJson("{ 'id': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' }}")); + JsonReader r = createReader(""" + { + "id": "id:unittest:smoke::doc1", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); @@ -1138,147 +1644,146 @@ public class JsonReaderTestCase { @Test public void testFeedWithTestAndSetConditionOrderingOne() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " }", - " },", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " },", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\''", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "put": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"", + "fields": { + "something": "smoketest", + "nalle": "bamse" + } + }, + { + "update": "id:unittest:testarray::whee", + "condition": "smoke.something == \\"smoketest\\"", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"" + } + ] + """); } @Test public void testFeedWithTestAndSetConditionOrderingTwo() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " }", - " },", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " },", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'remove': 'id:unittest:smoke::whee'", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "condition": "smoke.something == \\"smoketest\\"", + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + } + }, + { + "condition": "smoke.something == \\"smoketest\\"", + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "condition": "smoke.something == \\"smoketest\\"", + "remove": "id:unittest:smoke::whee" + } + ] + """); } @Test public void testFeedWithTestAndSetConditionOrderingThree() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " },", - " 'condition': 'smoke.something == \\'smoketest\\''", - " },", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 
'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " },", - " 'condition': 'smoke.something == \\'smoketest\\''", - " },", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\''", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "condition": "smoke.something == \\"smoketest\\"" + }, + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + }, + "condition": "smoke.something == \\"smoketest\\"" + }, + { + "remove": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"" + } + ] + """); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldAfterFieldsFieldShouldFailParse() { - final String jsonData = inputJson( - "[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " },", - " 'bjarne': 'stroustrup'", - " }", - "]"); + String jsonData = """ + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "bjarne": "stroustrup" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldBeforeFieldsFieldShouldFailParse() { - final String jsonData = inputJson( - "[", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 'what is this': 'nothing to see here',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " }", - "]"); - + String jsonData = """ + [ + { + "update": "id:unittest:testarray::whee", + "what is this": "nothing to see here", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldWithoutFieldsFieldShouldFailParse() { - String jsonData = inputJson( - "[", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'what is love': 'baby, do not hurt me... much'", - " }", - "]"); + String jsonData = """ + [ + { + "remove": "id:unittest:smoke::whee", + "what is love": "baby, do not hurt me... 
much + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @@ -1286,19 +1791,19 @@ public class JsonReaderTestCase { @Test public void testMissingOperation() { try { - String jsonData = inputJson( - "[", - " {", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " }", - "]"); + String jsonData = """ + [ + { + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1311,12 +1816,12 @@ public class JsonReaderTestCase { @Test public void testMissingFieldsMapInPut() { try { - String jsonData = inputJson( - "[", - " {", - " 'put': 'id:unittest:smoke::whee'", - " }", - "]"); + String jsonData = """ + [ + { + "put": "id:unittest:smoke::whee" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1329,12 +1834,12 @@ public class JsonReaderTestCase { @Test public void testMissingFieldsMapInUpdate() { try { - String jsonData = inputJson( - "[", - " {", - " 'update': 'id:unittest:smoke::whee'", - " }", - "]"); + String jsonData = """ + [ + { + "update": "id:unittest:smoke::whee" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1345,20 +1850,20 @@ public class JsonReaderTestCase { } @Test - public void testNullValues() { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testnull::doc1',", - " 'fields': {", - " 'intfield': null,", - " 'stringfield': null,", - " 'arrayfield': null,", - " 'weightedsetfield': null,", - " 'mapfield': null,", - " 'tensorfield': null", - " }", - "}")); - DocumentPut put = (DocumentPut) r.readSingleDocument(DocumentOperationType.PUT, - "id:unittest:testnull::doc1").operation(); - Document doc = put.getDocument(); + public void testNullValues() throws IOException { + Document doc = docFromJson(""" + { + "put": "id:unittest:testnull::doc1", + "fields": { + "intfield": null, + "stringfield": null, + "arrayfield": null, + "weightedsetfield": null, + "mapfield": null, + "tensorfield": null + } + } + """); assertFieldValueNull(doc, "intfield"); assertFieldValueNull(doc, "stringfield"); assertFieldValueNull(doc, "arrayfield"); @@ -1368,13 +1873,15 @@ public class JsonReaderTestCase { } @Test(expected=JsonReaderException.class) - public void testNullArrayElement() { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testnull::doc1',", - " 'fields': {", - " 'arrayfield': [ null ]", - " }", - "}")); - r.readSingleDocument(DocumentOperationType.PUT, "id:unittest:testnull::doc1"); + public void testNullArrayElement() throws IOException { + docFromJson(""" + { + "put": "id:unittest:testnull::doc1", + "fields": { + "arrayfield": [ null ] + } + } + """); fail(); } @@ -1429,30 +1936,31 @@ public class JsonReaderTestCase { @Test public void testParsingOfSparseTensorWithCells() { Tensor tensor = assertSparseTensorField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", - createPutWithSparseTensor( - """ - { - "type": "tensor(x{},y{})", - "cells": [ - { "address": { "x": "a", "y": "b" }, "value": 2.0 }, - { "address": { "x": "c", "y": "b" }, "value": 3.0 } - ] - } - """)); + createPutWithSparseTensor(""" + { + "type": "tensor(x{},y{})", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 }, + { "address": { "x": "c", "y": "b" }, "value": 3.0 } + ] + } + """)); assertTrue(tensor instanceof 
MappedTensor); // any functional instance is fine } @Test public void testParsingOfDenseTensorWithCells() { Tensor tensor = assertTensorField("{{x:0,y:0}:2.0,{x:1,y:0}:3.0}}", - createPutWithTensor(inputJson("{", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '0' },", - " 'value': 2.0 },", - " { 'address': { 'x': '1', 'y': '0' },", - " 'value': 3.0 }", - " ]", - "}"), "dense_unbound_tensor"), "dense_unbound_tensor"); + createPutWithTensor(""" + { + "cells": [ + { "address": { "x": 0, "y": 0 }, "value": 2.0 }, + { "address": { "x": 1, "y": 0 }, "value": 3.0 } + ] + } + """, + "dense_unbound_tensor"), + "dense_unbound_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1468,9 +1976,10 @@ public class JsonReaderTestCase { Tensor expected = builder.build(); Tensor tensor = assertTensorField(expected, - createPutWithTensor(inputJson("{", - " 'values': [2.0, 3.0, 4.0, 'inf', 6.0, 7.0]", - "}"), "dense_tensor"), "dense_tensor"); + createPutWithTensor(""" + { + "values": [2.0, 3.0, 4.0, "inf", 6.0, 7.0] + }""", "dense_tensor"), "dense_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1485,9 +1994,10 @@ public class JsonReaderTestCase { builder.cell().label("x", 1).label("y", 2).value(7.0); Tensor expected = builder.build(); Tensor tensor = assertTensorField(expected, - createPutWithTensor(inputJson("{", - " 'values': \"020304050607\"", - "}"), "dense_int8_tensor"), "dense_int8_tensor"); + createPutWithTensor(""" + { + "values": "020304050607" + }""", "dense_int8_tensor"), "dense_int8_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1501,10 +2011,14 @@ public class JsonReaderTestCase { builder.cell().label("x", "bar").label("y", 1).value(6.0); builder.cell().label("x", "bar").label("y", 2).value(7.0); Tensor expected = builder.build(); - String mixedJson = "{\"blocks\":[" + - "{\"address\":{\"x\":\"foo\"},\"values\":\"400040404080\"}," + - "{\"address\":{\"x\":\"bar\"},\"values\":\"40A040C040E0\"}" + - "]}"; + String mixedJson = """ + { + "blocks":[ + {"address":{"x":"foo"},"values":"400040404080"}, + {"address":{"x":"bar"},"values":"40A040C040E0"} + ] + } + """; var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor"); Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); } @@ -1587,10 +2101,14 @@ public class JsonReaderTestCase { builder.cell().label("x", 1).label("y", 2).value(7.0); Tensor expected = builder.build(); - String mixedJson = "{\"blocks\":{" + - "\"0\":[2.0,3.0,4.0]," + - "\"1\":[5.0,6.0,7.0]" + - "}}"; + String mixedJson = """ + { + "blocks":{ + "0":[2.0,3.0,4.0], + "1":[5.0,6.0,7.0] + } + } + """; Tensor tensor = assertTensorField(expected, createPutWithTensor(inputJson(mixedJson), "mixed_tensor"), "mixed_tensor"); assertTrue(tensor instanceof MixedTensor); // this matters for performance @@ -1599,12 +2117,14 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() { assertSparseTensorField("{{x:a,y:b}:2.0}", - createPutWithSparseTensor(inputJson("{", - " 'cells': [", - " { 'value': 2.0,", - " 'address': { 'x': 'a', 'y': 'b' } }", - " ]", - "}"))); + createPutWithSparseTensor(""" + { + "cells": [ + { "value": 2.0, + "address": { "x": "a", "y": "b" } } + ] + } + """)); } @Test @@ -1634,91 +2154,119 @@ public class JsonReaderTestCase { @Test public void testAssignUpdateOfTensorWithCells() { assertTensorAssignUpdateSparseField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", - 
createAssignUpdateWithSparseTensor(inputJson("{", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' },", - " 'value': 2.0 },", - " { 'address': { 'x': 'c', 'y': 'b' },", - " 'value': 3.0 }", - " ]", - "}"))); + createAssignUpdateWithSparseTensor(""" + { + "cells": [ + { "address": { "x": "a", "y": "b" }, + "value": 2.0 }, + { "address": { "x": "c", "y": "b" }, + "value": 3.0 } + ] + } + """)); } @Test public void testAssignUpdateOfTensorDenseShortForm() { assertTensorAssignUpdateDenseField("tensor(x[2],y[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]", - createAssignUpdateWithTensor(inputJson("{", - " 'values': [1,2,3,4,5,6]", - "}"), - "dense_tensor")); + createAssignUpdateWithTensor(""" + { + "values": [1,2,3,4,5,6] + } + """, + "dense_tensor")); } @Test public void tensor_modify_update_with_replace_operation() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.REPLACE, "sparse_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_add_operation() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, "sparse_tensor", - inputJson("{", - " 'operation': 'add',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + """ + { + "operation": "add", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_multiply_operation() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "sparse_tensor", - inputJson("{", - " 'operation': 'multiply',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + """ + { + "operation": "multiply", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_create_non_existing_cells_true() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, true, "sparse_tensor", - inputJson("{", - " 'operation': 'add',", - " 'create': true,", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + """ + { + "operation": "add", + "create": true, + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_create_non_existing_cells_false() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, false, "sparse_tensor", - inputJson("{", - " 'operation': 'add',", - " 'create': false,", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + """ + { + "operation": "add", + "create": false, + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_treats_the_input_tensor_as_sparse() { // Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions). 
assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}", - TensorModifyUpdate.Operation.REPLACE, "dense_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '0' }, 'value': 2.0 },", - " { 'address': { 'x': '1', 'y': '2' }, 'value': 3.0 } ]}")); + TensorModifyUpdate.Operation.REPLACE, "dense_tensor", + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "0", "y": "0" }, "value": 2.0 }, + { "address": { "x": "1", "y": "2" }, "value": 3.0 } + ] + }"""); } @Test public void tensor_modify_update_on_non_tensor_field_throws() { try { - JsonReader reader = createReader(inputJson("{ 'update': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'something': {", - " 'modify': {} }}}")); + JsonReader reader = createReader(""" + { + "update": "id:unittest:smoke::doc1", + "fields": { + "something": { + "modify": {} + } + } + } + """); reader.readSingleDocument(DocumentOperationType.UPDATE, "id:unittest:smoke::doc1"); fail("Expected exception"); } @@ -1732,95 +2280,125 @@ public class JsonReaderTestCase { public void tensor_modify_update_on_dense_unbound_tensor_throws() { illegalTensorModifyUpdate("Error in 'dense_unbound_tensor': A modify update cannot be applied to tensor types with indexed unbound dimensions. Field 'dense_unbound_tensor' has unsupported tensor type 'tensor(x[],y[])'", "dense_unbound_tensor", - "{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '0' }, 'value': 2.0 } ]}"); + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "0", "y": "0" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_on_sparse_tensor_with_single_dimension_short_form() { - assertTensorModifyUpdate("{{x:a}:2.0, {x:c}: 3.0}", TensorModifyUpdate.Operation.REPLACE, "sparse_single_dimension_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'cells': {", - " 'a': 2.0,", - " 'c': 3.0 }}")); + assertTensorModifyUpdate("{{x:a}:2.0, {x:c}: 3.0}", TensorModifyUpdate.Operation.REPLACE, "sparse_single_dimension_tensor", + """ + { + "operation": "replace", + "cells": { + "a": 2.0, + "c": 3.0 + } + }"""); } @Test public void tensor_modify_update_with_replace_operation_mixed() { assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "a", "y": "0" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() { assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'blocks': [", - " { 'address': { 'x': 'a' }, 'values': [1,2,3] } ]}")); + """ + { + "operation": "replace", + "blocks": [ + { "address": { "x": "a" }, "values": [1,2,3] } + ] + }"""); } @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_must_specify_full_subspace() { illegalTensorModifyUpdate("Error in 'mixed_tensor': At {x:a}: Expected 3 values, but got 2", - "mixed_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'blocks': {", - " 'a': [2,3] } }")); + "mixed_tensor", + """ + { + "operation": "replace", + "blocks": { + "a": [2,3] + } + }"""); } @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() { 
assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", - inputJson("{", - " 'operation': 'replace',", - " 'blocks': {", - " 'a': [1,2,3] } }")); + """ + { + "operation": "replace", + "blocks": { + "a": [1,2,3] + } + }"""); } @Test public void tensor_modify_update_with_add_operation_mixed() { assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor", - inputJson("{", - " 'operation': 'add',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + """ + { + "operation": "add", + "cells": [ + { "address": { "x": "a", "y": "0" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_multiply_operation_mixed() { assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor", - inputJson("{", - " 'operation': 'multiply',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + """ + { + "operation": "multiply", + "cells": [ + { "address": { "x": "a", "y": "0" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_out_of_bound_cells_throws() { illegalTensorModifyUpdate("Error in 'dense_tensor': Dimension 'y' has label '3' but type is tensor(x[2],y[3])", "dense_tensor", - "{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"); + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "0", "y": "3" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() { illegalTensorModifyUpdate("Error in 'mixed_tensor': Dimension 'y' has label '3' but type is tensor(x{},y[3])", "mixed_tensor", - "{", - " 'operation': 'replace',", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"); + """ + { + "operation": "replace", + "cells": [ + { "address": { "x": "0", "y": "3" }, "value": 2.0 } + ] + }"""); } @@ -1828,87 +2406,113 @@ public class JsonReaderTestCase { public void tensor_modify_update_with_unknown_operation_throws() { illegalTensorModifyUpdate("Error in 'sparse_tensor': Unknown operation 'unknown' in modify update for field 'sparse_tensor'", "sparse_tensor", - "{", - " 'operation': 'unknown',", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"); + """ + { + "operation": "unknown", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_modify_update_without_operation_throws() { illegalTensorModifyUpdate("Error in 'sparse_tensor': Modify update for field 'sparse_tensor' does not contain an operation", "sparse_tensor", - "{", - " 'cells': [] }"); + """ + { + "cells": [] + }"""); } @Test public void tensor_modify_update_without_cells_throws() { illegalTensorModifyUpdate("Error in 'sparse_tensor': Modify update for field 'sparse_tensor' does not contain tensor cells", "sparse_tensor", - "{", - " 'operation': 'replace' }"); + """ + { + "operation": "replace" + }"""); } @Test public void tensor_modify_update_with_unknown_content_throws() { illegalTensorModifyUpdate("Error in 'sparse_tensor': Unknown JSON string 'unknown' in modify update for field 'sparse_tensor'", "sparse_tensor", - "{", - " 'unknown': 'here' }"); + """ + { + "unknown": "here" + }"""); } @Test public void tensor_add_update_on_sparse_tensor() { assertTensorAddUpdate("{{x:a,y:b}:2.0, {x:c,y:d}: 3.0}", "sparse_tensor", - inputJson("{", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' }, 
'value': 2.0 },", - " { 'address': { 'x': 'c', 'y': 'd' }, 'value': 3.0 } ]}")); + """ + { + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 }, + { "address": { "x": "c", "y": "d" }, "value": 3.0 } + ] + }"""); } @Test public void tensor_add_update_on_sparse_tensor_with_single_dimension_short_form() { assertTensorAddUpdate("{{x:a}:2.0, {x:c}: 3.0}", "sparse_single_dimension_tensor", - inputJson("{", - " 'cells': {", - " 'a': 2.0,", - " 'c': 3.0 }}")); + """ + { + "cells": { + "a": 2.0, + "c": 3.0 + } + }"""); } @Test public void tensor_add_update_on_mixed_tensor() { assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor", - inputJson("{", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", - " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + """ + { + "cells": [ + { "address": { "x": "a", "y": "0" }, "value": 2.0 }, + { "address": { "x": "a", "y": "1" }, "value": 3.0 } + ] + }"""); } @Test public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() { illegalTensorAddUpdate("Error in 'mixed_tensor': Index 3 out of bounds for length 3", "mixed_tensor", - "{", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"); + """ + { + "cells": [ + { "address": { "x": "0", "y": "3" }, "value": 2.0 } + ] + }"""); } @Test public void tensor_add_update_on_dense_tensor_throws() { illegalTensorAddUpdate("Error in 'dense_tensor': An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'", "dense_tensor", - "{", - " 'cells': [] }"); + """ + { + "cells": [ ] + }"""); } @Test public void tensor_add_update_on_not_fully_specified_cell_throws() { illegalTensorAddUpdate("Error in 'sparse_tensor': Missing a label for dimension 'y' for tensor(x{},y{})", "sparse_tensor", - "{", - " 'cells': [", - " { 'address': { 'x': 'a' }, 'value': 2.0 } ]}"); + """ + { + "cells": [ + { "address": { "x": "a" }, "value": 2.0 } + ] + }"""); } @Test @@ -1924,146 +2528,176 @@ public class JsonReaderTestCase { @Test public void tensor_remove_update_on_sparse_tensor() { assertTensorRemoveUpdate("{{x:a,y:b}:1.0,{x:c,y:d}:1.0}", "sparse_tensor", - inputJson("{", - " 'addresses': [", - " { 'x': 'a', 'y': 'b' },", - " { 'x': 'c', 'y': 'd' } ]}")); + """ + { + "addresses": [ + { "x": "a", "y": "b" }, + { "x": "c", "y": "d" } + ] + }"""); } @Test public void tensor_remove_update_on_mixed_tensor() { assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", - inputJson("{", - " 'addresses': [", - " { 'x': '1' },", - " { 'x': '2' } ]}")); + """ + { + "addresses": [ + { "x": "1" }, + { "x": "2" } + ] + }"""); } @Test public void tensor_remove_update_on_sparse_tensor_with_not_fully_specified_address() { assertTensorRemoveUpdate("{{y:b}:1.0,{y:d}:1.0}", "sparse_tensor", - inputJson("{", - " 'addresses': [", - " { 'y': 'b' },", - " { 'y': 'd' } ]}")); + """ + { + "addresses": [ + { "y": "b" }, + { "y": "d" } + ] + }"""); } @Test public void tensor_remove_update_on_mixed_tensor_with_not_fully_specified_address() { assertTensorRemoveUpdate("{{x:1,z:a}:1.0,{x:2,z:b}:1.0}", "mixed_tensor_adv", - inputJson("{", - " 'addresses': [", - " { 'x': '1', 'z': 'a' },", - " { 'x': '2', 'z': 'b' } ]}")); + """ + { + "addresses": [ + { "x": "1", "z": "a" }, + { "x": "2", "z": "b" } + ] + }"""); } @Test public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { illegalTensorRemoveUpdate("Error in 'mixed_tensor': Indexed 
dimension address 'y' should not be specified in remove update", "mixed_tensor", - "{", - " 'addresses': [", - " { 'x': '1', 'y': '0' },", - " { 'x': '2', 'y': '0' } ]}"); + """ + { + "addresses": [ + { "x": "1", "y": "0" }, + { "x": "2", "y": "0" } + ] + }"""); } @Test public void tensor_remove_update_on_dense_tensor_throws() { illegalTensorRemoveUpdate("Error in 'dense_tensor': A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'", "dense_tensor", - "{", - " 'addresses': [] }"); + """ + { + "addresses": [] + }"""); } @Test public void tensor_remove_update_with_stray_dimension_throws() { illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{},y{}) does not contain dimension 'foo'", - "sparse_tensor", - "{", - " 'addresses': [", - " { 'x': 'a', 'foo': 'b' } ]}"); + "sparse_tensor", + """ + { + "addresses": [ + { "x": "a", "foo": "b" } + ] + }"""); illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{}) does not contain dimension 'foo'", - "sparse_tensor", - "{", - " 'addresses': [", - " { 'x': 'c' },", - " { 'x': 'a', 'foo': 'b' } ]}"); + "sparse_tensor", + """ + { + "addresses": [ + { "x": "c" }, + { "x": "a", "foo": "b" } + ] + }"""); } @Test public void tensor_remove_update_without_cells_throws() { illegalTensorRemoveUpdate("Error in 'sparse_tensor': Remove update for field 'sparse_tensor' does not contain tensor addresses", "sparse_tensor", - "{'addresses': [] }"); + """ + { + "addresses": [] + }"""); illegalTensorRemoveUpdate("Error in 'mixed_tensor': Remove update for field 'mixed_tensor' does not contain tensor addresses", "mixed_tensor", - "{'addresses': [] }"); + """ + { + "addresses": [] + }"""); } @Test public void require_that_parser_propagates_datatype_parser_errors_predicate() { assertParserErrorMatches( "Error in document 'id:unittest:testpredicate::0' - could not parse field 'boolean' of type 'predicate': " + - "line 1:10 no viable alternative at character '>'", - - "[", - " {", - " 'fields': {", - " 'boolean': 'timestamp > 9000'", - " },", - " 'put': 'id:unittest:testpredicate::0'", - " }", - "]" - ); + "line 1:10 no viable alternative at character '>'", + """ + [ + { + "fields": { + "boolean": "timestamp > 9000" + }, + "put": "id:unittest:testpredicate::0" + } + ] + """); } @Test public void require_that_parser_propagates_datatype_parser_errors_string_as_int() { assertParserErrorMatches( "Error in document 'id:unittest:testint::0' - could not parse field 'integerfield' of type 'int': " + - "For input string: \" 1\"", - - "[", - " {", - " 'fields': {", - " 'integerfield': ' 1'", - " },", - " 'put': 'id:unittest:testint::0'", - " }", - "]" - ); + "For input string: \" 1\"", + """ + [ + { + "fields": { + "integerfield": " 1" + }, + "put": "id:unittest:testint::0" + } + ] + """); } @Test public void require_that_parser_propagates_datatype_parser_errors_overflowing_int() { assertParserErrorMatches( "Error in document 'id:unittest:testint::0' - could not parse field 'integerfield' of type 'int': " + - "For input string: \"281474976710656\"", - - "[", - " {", - " 'fields': {", - " 'integerfield': 281474976710656", - " },", - " 'put': 'id:unittest:testint::0'", - " }", - "]" - ); + "For input string: \"281474976710656\"", + """ + [ + { + "fields": { + "integerfield": 281474976710656 + }, + "put": "id:unittest:testint::0" + } + ] + """); } @Test public void requireThatUnknownDocTypeThrowsIllegalArgumentException() { - final String jsonData = inputJson( - "[", - " {", 
- " 'put': 'id:ns:walrus::walrus1',", - " 'fields': {", - " 'aField': 42", - " }", - " }", - "]"); + String jsonData = """ + [ + { + "put": "id:ns:walrus::walrus1", + "fields": { + "aField": 42 + } + } + ] + """; try { new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail(); @@ -2113,30 +2747,40 @@ public class JsonReaderTestCase { return createPutWithTensor(inputTensor, "sparse_tensor"); } private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) { - JsonReader reader = createReader(inputJson("[", - "{ 'put': '" + TENSOR_DOC_ID + "',", - " 'fields': {", - " '" + tensorFieldName + "': " + inputTensor + " }}]")); - return (DocumentPut) reader.next(); + JsonReader streaming = createReader(""" + { + "fields": { + "%s": %s + } + } + """.formatted(tensorFieldName, inputTensor)); + DocumentPut lazyParsed = (DocumentPut) streaming.readSingleDocumentStreaming(DocumentOperationType.PUT, TENSOR_DOC_ID).operation(); + JsonReader reader = createReader(""" + [ + { + "put": "%s", + "fields": { + "%s": %s + } + } + ]""".formatted(TENSOR_DOC_ID, tensorFieldName, inputTensor)); + DocumentPut bufferParsed = (DocumentPut) reader.next(); + assertEquals(lazyParsed, bufferParsed); + return bufferParsed; } private DocumentUpdate createAssignUpdateWithSparseTensor(String inputTensor) { return createAssignUpdateWithTensor(inputTensor, "sparse_tensor"); } private DocumentUpdate createAssignUpdateWithTensor(String inputTensor, String tensorFieldName) { - JsonReader reader = createReader(inputJson("[", - "{ 'update': '" + TENSOR_DOC_ID + "',", - " 'fields': {", - " '" + tensorFieldName + "': {", - " 'assign': " + (inputTensor != null ? inputTensor : "null") + " } } } ]")); - return (DocumentUpdate) reader.next(); + return createTensorUpdate("assign", inputTensor, tensorFieldName); } private static Tensor assertSparseTensorField(String expectedTensor, DocumentPut put) { return assertTensorField(expectedTensor, put, "sparse_tensor"); } private Tensor assertTensorField(String expectedTensor, String fieldName, String inputJson) { - return assertTensorField(expectedTensor, createPutWithTensor(inputJson, fieldName), fieldName); + return assertTensorField(expectedTensor, createPutWithTensor(inputJson(inputJson), fieldName), fieldName); } private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) { return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName); @@ -2209,12 +2853,29 @@ public class JsonReaderTestCase { } private DocumentUpdate createTensorUpdate(String operation, String tensorJson, String tensorFieldName) { - JsonReader reader = createReader(inputJson("[", - "{ 'update': '" + TENSOR_DOC_ID + "',", - " 'fields': {", - " '" + tensorFieldName + "': {", - " '" + operation + "': " + tensorJson + " }}}]")); - return (DocumentUpdate) reader.next(); + JsonReader streaming = createReader(""" + { + "fields": { + "%s": { + "%s": %s + } + } + }""".formatted(tensorFieldName, operation, tensorJson)); + DocumentUpdate lazyParsed = (DocumentUpdate) streaming.readSingleDocumentStreaming(DocumentOperationType.UPDATE, TENSOR_DOC_ID).operation(); + JsonReader reader = createReader(""" + [ + { + "update": "%s", + "fields": { + "%s": { + "%s": %s + } + } + } + ]""".formatted(TENSOR_DOC_ID, tensorFieldName, operation, tensorJson)); + DocumentUpdate bufferParsed = (DocumentUpdate) reader.next(); + assertEquals(lazyParsed, bufferParsed); + return bufferParsed; } private void assertTensorAddUpdate(String expectedTensor, String 
tensorFieldName, String tensorJson) { diff --git a/document/src/test/java/com/yahoo/document/json/LazyTokenBufferTest.java b/document/src/test/java/com/yahoo/document/json/LazyTokenBufferTest.java new file mode 100644 index 00000000000..3ed2ed531c3 --- /dev/null +++ b/document/src/test/java/com/yahoo/document/json/LazyTokenBufferTest.java @@ -0,0 +1,132 @@ +package com.yahoo.document.json; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.yahoo.document.json.TokenBuffer.Token; +import org.junit.Test; + +import java.io.IOException; +import java.util.function.Supplier; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * @author jonmv + */ +public class LazyTokenBufferTest { + + @Test + public void testBuffer() throws IOException { + String json = """ + { + "fields": { + "foo": "bar", + "baz": [1, 2, 3], + "quu": { "qux": null } + } + }"""; + JsonParser parser = new JsonFactory().createParser(json); + parser.nextValue(); + parser.nextValue(); + assertEquals(JsonToken.START_OBJECT, parser.currentToken()); + assertEquals("fields", parser.currentName()); + + // Peeking through the buffer doesn't change nesting. + LazyTokenBuffer buffer = new LazyTokenBuffer(parser); + assertEquals(JsonToken.START_OBJECT, buffer.current()); + assertEquals("fields", buffer.currentName()); + assertEquals(1, buffer.nesting()); + + Supplier<Token> lookahead = buffer.lookahead(); + Token peek = lookahead.get(); + assertEquals(JsonToken.VALUE_STRING, peek.token); + assertEquals("foo", peek.name); + assertEquals("bar", peek.text); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertEquals(JsonToken.START_ARRAY, peek.token); + assertEquals("baz", peek.name); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertEquals(JsonToken.VALUE_NUMBER_INT, peek.token); + assertEquals("1", peek.text); + + peek = lookahead.get(); + assertEquals(JsonToken.VALUE_NUMBER_INT, peek.token); + assertEquals("2", peek.text); + + peek = lookahead.get(); + assertEquals(JsonToken.VALUE_NUMBER_INT, peek.token); + assertEquals("3", peek.text); + + peek = lookahead.get(); + assertEquals(JsonToken.END_ARRAY, peek.token); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertEquals(JsonToken.START_OBJECT, peek.token); + assertEquals("quu", peek.name); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertEquals(JsonToken.VALUE_NULL, peek.token); + assertEquals("qux", peek.name); + + peek = lookahead.get(); + assertEquals(JsonToken.END_OBJECT, peek.token); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertEquals(JsonToken.END_OBJECT, peek.token); + assertEquals(1, buffer.nesting()); + + peek = lookahead.get(); + assertNull(peek); + + // Parser is now at the end. + assertEquals(JsonToken.END_OBJECT, parser.nextToken()); + assertNull(parser.nextToken()); + + // Repeat iterating through the buffer, this time advancing it, and see that nesting changes. 
+ assertEquals(JsonToken.VALUE_STRING, buffer.next()); + assertEquals("foo", buffer.currentName()); + assertEquals("bar", buffer.currentText()); + assertEquals(1, buffer.nesting()); + + assertEquals(JsonToken.START_ARRAY, buffer.next()); + assertEquals("baz", buffer.currentName()); + assertEquals(2, buffer.nesting()); + + assertEquals(JsonToken.VALUE_NUMBER_INT, buffer.next()); + assertEquals("1", buffer.currentText()); + + assertEquals(JsonToken.VALUE_NUMBER_INT, buffer.next()); + assertEquals("2", buffer.currentText()); + + assertEquals(JsonToken.VALUE_NUMBER_INT, buffer.next()); + assertEquals("3", buffer.currentText()); + + assertEquals(JsonToken.END_ARRAY, buffer.next()); + assertEquals(1, buffer.nesting()); + + assertEquals(JsonToken.START_OBJECT, buffer.next()); + assertEquals("quu", buffer.currentName()); + assertEquals(2, buffer.nesting()); + + assertEquals(JsonToken.VALUE_NULL, buffer.next()); + assertEquals("qux", buffer.currentName()); + + assertEquals(JsonToken.END_OBJECT, buffer.next()); + assertEquals(1, buffer.nesting()); + + assertEquals(JsonToken.END_OBJECT, buffer.next()); + assertEquals(0, buffer.nesting()); + + assertNull(buffer.next()); + } + +} diff --git a/document/src/vespa/document/repo/configbuilder.cpp b/document/src/vespa/document/repo/configbuilder.cpp index 5f40bde1966..cf563c5c783 100644 --- a/document/src/vespa/document/repo/configbuilder.cpp +++ b/document/src/vespa/document/repo/configbuilder.cpp @@ -19,6 +19,7 @@ DatatypeConfig::DatatypeConfig() { } DatatypeConfig::DatatypeConfig(const DatatypeConfig&) = default; +DatatypeConfig::~DatatypeConfig() = default; DatatypeConfig& DatatypeConfig::operator=(const DatatypeConfig&) = default; void DatatypeConfig::addNestedType(const TypeOrId &t) { diff --git a/document/src/vespa/document/repo/configbuilder.h b/document/src/vespa/document/repo/configbuilder.h index 4ef17425c1b..61924b2b41a 100644 --- a/document/src/vespa/document/repo/configbuilder.h +++ b/document/src/vespa/document/repo/configbuilder.h @@ -17,8 +17,8 @@ struct DatatypeConfig : DocumenttypesConfig::Documenttype::Datatype { std::vector<DatatypeConfig> nested_types; DatatypeConfig(); - DatatypeConfig(const DatatypeConfig&); + ~DatatypeConfig(); DatatypeConfig& operator=(const DatatypeConfig&); DatatypeConfig &setId(int32_t i) { id = i; return *this; } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index 2854ef8836a..72f0fb977d5 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -74,7 +74,17 @@ public class FileDownloader implements AutoCloseable { } } - Future<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { + /** Returns a future that times out if download takes too long, and return empty on unsuccessful download. 
*/ + public Future<Optional<File>> getFutureFileOrTimeout(FileReferenceDownload fileReferenceDownload) { + return getFutureFile(fileReferenceDownload) + .orTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally(thrown -> { + fileReferenceDownloader.failedDownloading(fileReferenceDownload.fileReference()); + return Optional.empty(); + }); + } + + CompletableFuture<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<File> file = getFileFromFileSystem(fileReference); @@ -135,7 +145,7 @@ public class FileDownloader implements AutoCloseable { } /** Start downloading, the future returned will be complete()d by receiving method in {@link FileReceiver} */ - private synchronized Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + private synchronized CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { return fileReferenceDownloader.startDownload(fileReferenceDownload); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 450801ce530..5ad197e8633 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -67,7 +68,7 @@ public class FileReferenceDownloader { int retryCount = 0; Connection connection = connectionPool.getCurrent(); do { - backoff(retryCount); + backoff(retryCount, end); if (FileDownloader.fileReferenceExists(fileReference, downloadDirectory)) return; @@ -79,24 +80,26 @@ public class FileReferenceDownloader { // exist on just one config server, and which one could be different for each file reference), so // switch to a new connection for every retry connection = connectionPool.switchConnection(connection); - } while (retryCount < 5 || Instant.now().isAfter(end)); + } while (Instant.now().isBefore(end)); fileReferenceDownload.future().completeExceptionally(new RuntimeException("Failed getting " + fileReference)); downloads.remove(fileReference); } - private void backoff(int retryCount) { + private void backoff(int retryCount, Instant end) { if (retryCount > 0) { try { - long sleepTime = Math.min(120_000, (long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis()); - Thread.sleep(sleepTime); + long sleepTime = Math.min(120_000, + Math.min((long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis(), + Duration.between(Instant.now(), end).toMillis())); + if (sleepTime > 0) Thread.sleep(sleepTime); } catch (InterruptedException e) { /* ignored */ } } } - Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<FileReferenceDownload> inProgress = downloads.get(fileReference); if (inProgress.isPresent()) return inProgress.get().future(); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java 
b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index cd264fa9f7c..bd7ed3369eb 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -211,7 +211,7 @@ public class Flags { // TODO: Move to a permanent flag public static final UnboundListFlag<String> ALLOWED_ATHENZ_PROXY_IDENTITIES = defineListFlag( "allowed-athenz-proxy-identities", List.of(), String.class, - List.of("bjorncs", "tokle"), "2021-02-10", "2024-02-01", + List.of("bjorncs", "tokle"), "2021-02-10", "2024-04-01", "Allowed Athenz proxy identities", "takes effect at redeployment"); @@ -272,7 +272,7 @@ public class Flags { public static final UnboundBooleanFlag ENABLE_PROXY_PROTOCOL_MIXED_MODE = defineFeatureFlag( "enable-proxy-protocol-mixed-mode", true, - List.of("tokle"), "2022-05-09", "2024-02-01", + List.of("tokle"), "2022-05-09", "2024-04-01", "Enable or disable proxy protocol mixed mode", "Takes effect on redeployment", INSTANCE_ID); @@ -316,7 +316,7 @@ public class Flags { public static final UnboundStringFlag CORE_ENCRYPTION_PUBLIC_KEY_ID = defineStringFlag( "core-encryption-public-key-id", "", - List.of("vekterli"), "2022-11-03", "2024-02-01", + List.of("vekterli"), "2022-11-03", "2024-06-01", "Specifies which public key to use for core dump encryption.", "Takes effect on the next tick.", NODE_TYPE, HOSTNAME); @@ -348,40 +348,40 @@ public class Flags { public static final UnboundBooleanFlag WRITE_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB = defineFeatureFlag( "write-config-server-session-data-as-blob", false, - List.of("hmusum"), "2023-07-19", "2024-02-01", + List.of("hmusum"), "2023-07-19", "2024-03-01", "Whether to write config server session data in one blob or as individual paths", "Takes effect immediately"); public static final UnboundBooleanFlag READ_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB = defineFeatureFlag( "read-config-server-session-data-as-blob", false, - List.of("hmusum"), "2023-07-19", "2024-02-01", + List.of("hmusum"), "2023-07-19", "2024-03-01", "Whether to read config server session data from session data blob or from individual paths", "Takes effect immediately"); public static final UnboundBooleanFlag MORE_WIREGUARD = defineFeatureFlag( "more-wireguard", false, - List.of("andreer"), "2023-08-21", "2024-01-24", + List.of("andreer"), "2023-08-21", "2024-02-24", "Use wireguard in INternal enCLAVES", "Takes effect on next host-admin run", HOSTNAME, CLOUD_ACCOUNT); public static final UnboundBooleanFlag IPV6_AWS_TARGET_GROUPS = defineFeatureFlag( "ipv6-aws-target-groups", false, - List.of("andreer"), "2023-08-28", "2024-01-24", + List.of("andreer"), "2023-08-28", "2024-02-24", "Always use IPv6 target groups for load balancers in aws", "Takes effect on next load-balancer provisioning", HOSTNAME, CLOUD_ACCOUNT); public static final UnboundBooleanFlag PROVISION_IPV6_ONLY_AWS = defineFeatureFlag( "provision-ipv6-only", false, - List.of("andreer"), "2023-08-28", "2024-01-24", + List.of("andreer"), "2023-08-28", "2024-02-24", "Provision without private IPv4 addresses in INternal enCLAVES in AWS", "Takes effect on next host provisioning / run of host-admin", HOSTNAME, CLOUD_ACCOUNT); public static final UnboundIntFlag CONTENT_LAYER_METADATA_FEATURE_LEVEL = defineIntFlag( "content-layer-metadata-feature-level", 0, - List.of("vekterli"), "2022-09-12", "2024-02-01", + List.of("vekterli"), "2022-09-12", "2024-06-01", "Value semantics: 0) legacy behavior, 1) operation cancellation, 2) operation " + "cancellation and ephemeral content node 
sequence numbers for bucket replicas", "Takes effect at redeployment", @@ -396,7 +396,7 @@ public class Flags { public static final UnboundStringFlag UNKNOWN_CONFIG_DEFINITION = defineStringFlag( "unknown-config-definition", "warn", - List.of("hmusum"), "2023-09-25", "2024-02-01", + List.of("hmusum"), "2023-09-25", "2024-03-01", "How to handle user config referencing unknown config definitions. Valid values are 'warn' and 'fail'", "Takes effect at redeployment", INSTANCE_ID); @@ -410,7 +410,7 @@ public class Flags { public static final UnboundStringFlag ENDPOINT_CONFIG = defineStringFlag( "endpoint-config", "legacy", - List.of("mpolden", "tokle"), "2023-10-06", "2024-02-01", + List.of("mpolden", "tokle"), "2023-10-06", "2024-06-01", "Set the endpoint config to use for an application. Must be 'legacy', 'combined' or 'generated'. See EndpointConfig for further details", "Takes effect on next deployment through controller", TENANT_ID, APPLICATION, INSTANCE_ID); @@ -421,25 +421,6 @@ public class Flags { "Whether to send cloud trial email notifications", "Takes effect immediately"); - public static final UnboundLongFlag MERGING_MAX_MEMORY_USAGE_PER_NODE = defineLongFlag( - "merging-max-memory-usage-per-node", -1, - List.of("vekterli"), "2023-11-03", "2024-03-01", - "Soft limit of the maximum amount of memory that can be used across merge operations on a content node. " + - "Value semantics: < 0: unlimited (legacy behavior), == 0: auto-deduced from node HW and config," + - " > 0: explicit memory usage limit in bytes.", - "Takes effect at redeployment", - INSTANCE_ID); - - public static final UnboundBooleanFlag USE_PER_DOCUMENT_THROTTLED_DELETE_BUCKET = defineFeatureFlag( - "use-per-document-throttled-delete-bucket", false, - List.of("vekterli"), "2023-11-13", "2024-03-01", - "If set, DeleteBucket operations are internally expanded to an individually persistence-" + - "throttled remove per document stored in the bucket. 
This makes the cost model of " + - "executing a DeleteBucket symmetrical with feeding the documents to the bucket in the " + - "first place.", - "Takes effect at redeployment", - INSTANCE_ID); - public static final UnboundBooleanFlag ENABLE_NEW_PAYMENT_METHOD_FLOW = defineFeatureFlag( "enable-new-payment-method-flow", false, List.of("bjorncs"), "2023-11-29", "2024-03-01", @@ -447,16 +428,9 @@ public class Flags { "Takes effect immediately", TENANT_ID, CONSOLE_USER_EMAIL); - public static final UnboundBooleanFlag CENTRALIZED_AUTHZ = defineFeatureFlag( - "centralized-authz", true, - List.of("mortent"), "2023-11-27", "2024-02-01", - "Use centralized authorization checks", - "Takes effect immediately", - CONSOLE_USER_EMAIL); - public static final UnboundBooleanFlag RESTART_ON_DEPLOY_WHEN_ONNX_MODEL_CHANGES = defineFeatureFlag( "restart-on-deploy-when-onnx-model-changes", true, - List.of("hmusum"), "2023-12-04", "2024-02-01", + List.of("hmusum"), "2023-12-04", "2024-03-01", "If set, restart on deploy if onnx model or onnx model options used by a container cluster change", "Takes effect at redeployment", INSTANCE_ID); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java index eb47c691334..4edda472531 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java @@ -13,6 +13,7 @@ import java.util.Set; import java.util.function.Predicate; import java.util.regex.Pattern; +import static com.yahoo.vespa.flags.Dimension.APPLICATION; import static com.yahoo.vespa.flags.Dimension.CLOUD_ACCOUNT; import static com.yahoo.vespa.flags.Dimension.INSTANCE_ID; import static com.yahoo.vespa.flags.Dimension.CLUSTER_ID; @@ -35,11 +36,16 @@ public class PermanentFlags { static final Instant CREATED_AT = Instant.EPOCH; static final Instant EXPIRES_AT = ZonedDateTime.of(2100, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC).toInstant(); + // TODO(mpolden): Remove this flag public static final UnboundBooleanFlag USE_ALTERNATIVE_ENDPOINT_CERTIFICATE_PROVIDER = defineFeatureFlag( "use-alternative-endpoint-certificate-provider", false, "Whether to use an alternative CA when provisioning new certificates", "Takes effect only on initial application deployment - not on later certificate refreshes!"); + public static final UnboundStringFlag ENDPOINT_CERTIFICATE_PROVIDER = defineStringFlag( + "endpoint-certificate-provider", "digicert", "The CA to use for endpoint certificates. 
Must be 'digicert', 'globalsign' or 'zerossl'", + "Takes effect on initial deployment", TENANT_ID, APPLICATION, INSTANCE_ID); + public static final UnboundStringFlag JVM_GC_OPTIONS = defineStringFlag( "jvm-gc-options", "", "Sets default jvm gc options", diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 1a9caaa5ca1..7c5e8912e49 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -97,7 +97,7 @@ public class EmbedExpression extends Expression { Tensor.Cell cell = cells.next(); builder.cell() .label(targetType.mappedSubtype().dimensions().get(0).name(), i) - .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0)) + .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().numericLabel(0)) .value(cell.getValue()); } } diff --git a/integration/intellij/build.gradle.kts b/integration/intellij/build.gradle.kts index 980e8878efc..b4f2c92ec44 100644 --- a/integration/intellij/build.gradle.kts +++ b/integration/intellij/build.gradle.kts @@ -4,7 +4,7 @@ import org.jetbrains.grammarkit.tasks.GenerateParserTask plugins { id("java-library") - id("org.jetbrains.intellij") version "1.16.1" + id("org.jetbrains.intellij") version "1.17.0" id("org.jetbrains.grammarkit") version "2022.3.2.1" id("maven-publish") // to deploy the plugin into a Maven repo } diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index a805fc79a64..da3068c3744 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -105,8 +105,9 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public Encoding encode(String text) { return encode(text, Language.UNKNOWN); } public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); } - public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); } - public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); } + + public String decode(long [] tokens) { return decode(tokens, Language.UNKNOWN); } + public String decode(long [] tokens, Language language) { return resolve(language).decode(tokens); } @Override public void close() { diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 1ffb879e57e..dc6a62cc463 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -344,7 +344,9 @@ "public java.lang.String getDestination()", "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", - "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)" + "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", + "public java.util.Map getContextValues()", + "public com.yahoo.language.process.Embedder$Context setContextValues(java.util.Map)" ], "fields" : [ ] }, diff --git 
a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index fa141977d5d..f6fdb86d01f 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -88,6 +88,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; + private Map<String, String> contextValues; public Context(String destination) { this.destination = destination; @@ -138,6 +139,15 @@ public interface Embedder { this.embedderId = embedderId; return this; } + + /** Returns a read-only map of context key-values which can be looked up during conversion. */ + public Map<String, String> getContextValues() { return contextValues; } + + public Context setContextValues(Map<String, String> contextValues) { + this.contextValues = contextValues; + return this; + } + } class FailingEmbedder implements Embedder { diff --git a/metrics/src/main/java/ai/vespa/metrics/Labels.java b/metrics/src/main/java/ai/vespa/metrics/Labels.java new file mode 100644 index 00000000000..98ab105cd66 --- /dev/null +++ b/metrics/src/main/java/ai/vespa/metrics/Labels.java @@ -0,0 +1,52 @@ +package ai.vespa.metrics; +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +/** + * @author yngveaasheim + */ +public enum Labels { + + // Any changes to labels in the ai.vespa namespace needs to be approved in an architecture reviews. + // We try to follow recommendations outlined in OpenTelemetry Semantic Conventions for these labels, https://opentelemetry.io/docs/specs/semconv/. + + // Labels to be decorated onto all tenant related metrics generated for Vespa Cloud. + CLUSTER("ai.vespa.cluster", "The name of a Vespa cluster."), + CLUSTER_TYPE("ai.vespa.cluster_type", "The type of a Vespa cluster, typically one of 'admin', 'container', 'content'."), + DEPLOYMENT_CLUSTER("ai.vespa.deployment_cluster", "Unique ID for a Vespa deployment cluster, in the format <tenant>.<application>.<instance>.<zone>.<cluster>."), + INSTANCE_ID("ai.vespa.instance_id", "The id of a Vespa application instance in the format <tenant>.<application>.<instance>."), + GROUP("ai.vespa.group", "The group id of a Vespa content node. Samples values are 'Group 1', 'Group 2', etc."), + SYSTEM("ai.vespa.system", "The name of a managed Vespa system, sample values are 'public', 'publiccd'."), + ZONE("ai.vespa.zone", "The name of a zone in managed Vespa, in the format <environment>.<region>. Sample name 'prod.aws-us-west-2a'."), + PARENT("ai.vespa.parent", "The fully qualified name of the parent host on which a Vespa node is running."), + + // Labels used for a subset of the metrics only: + CHAIN("ai.vespa.chain", "The name of a search chain"), + DOCUMENT_PROCESSOR("ai.vespa.document_processor", "Document processor name."), + SERVICE("ai.vespa.service", "Vespa service name, e.g. 'container', 'distributor', 'searchnode'."), + THREAD_POOL("ai.vespa.thread_pool", "Thread pool name."), + VERSION("ai.vespa.version", "Version of Vespa running on a node."), + + // TODO: Add other labels used by the metrics in the summary dashboard: "api, gcName, gpu, interface, operation, protocol(verify, requestType, role, scheme, status(verify) + + // Labels defined by OpenTelemetry Semantic Conventions external to Vespa + HOST_ARCH("host.arch", "The CPU architecture of a host, e.g. 'x86_64', 'arm64'. 
See also https://opentelemetry.io/docs/specs/semconv/resource/host/"), + HOST_NAME("host.name", "The fully qualified name of a host. See also https://opentelemetry.io/docs/specs/semconv/resource/host/"), + HOST_TYPE("host.type", "The type of a host. See also https://opentelemetry.io/docs/specs/semconv/resource/host/"), + HTTP_REQUEST_METHOD("http.request.method", "The HTTP request method specified. See also https://opentelemetry.io/docs/specs/semconv/attributes-registry/http/"), + HTTP_RESPONSE_STATUS_CODE("http.response.status_code", "The HTTP response code. See also https://opentelemetry.io/docs/specs/semconv/attributes-registry/http/"); + + private final String name; + private final String description; + + public String getName() { + return name; + } + public String getDescription() { + return description; + } + + Labels(String name, String description) { + this.name = name; + this.description = description; + } +} diff --git a/metrics/src/main/java/ai/vespa/metrics/StorageMetrics.java b/metrics/src/main/java/ai/vespa/metrics/StorageMetrics.java index dd8c2a2a1af..4d91cc1d989 100644 --- a/metrics/src/main/java/ai/vespa/metrics/StorageMetrics.java +++ b/metrics/src/main/java/ai/vespa/metrics/StorageMetrics.java @@ -152,6 +152,7 @@ public enum StorageMetrics implements VespaMetrics { VDS_DATASTORED_BUCKET_SPACE_BUCKET_DB_MEMORY_USAGE_ONHOLD_BYTES("vds.datastored.bucket_space.bucket_db.memory_usage.onhold_bytes", Unit.BYTE, "The number of bytes on hold"), VDS_DATASTORED_BUCKET_SPACE_BUCKET_DB_MEMORY_USAGE_USED_BYTES("vds.datastored.bucket_space.bucket_db.memory_usage.used_bytes", Unit.BYTE, "The number of used bytes (<= allocated_bytes)"), VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL("vds.datastored.bucket_space.buckets_total", Unit.BUCKET, "Total number buckets present in the bucket space (ready + not ready)"), + VDS_DATASTORED_BUCKET_SPACE_ENTRIES("vds.datastored.bucket_space.entries", Unit.DOCUMENT, "Number of entries (documents + tombstones) stored in the bucket space"), VDS_DATASTORED_BUCKET_SPACE_BYTES("vds.datastored.bucket_space.bytes", Unit.BYTE, "Bytes stored across all documents in the bucket space"), VDS_DATASTORED_BUCKET_SPACE_DOCS("vds.datastored.bucket_space.docs", Unit.DOCUMENT, "Documents stored in the bucket space"), VDS_DATASTORED_BUCKET_SPACE_READY_BUCKETS("vds.datastored.bucket_space.ready_buckets", Unit.BUCKET, "Number of ready buckets in the bucket space"), diff --git a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java index db70aaa63c7..a9d078e2a44 100644 --- a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java +++ b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java @@ -116,7 +116,7 @@ public class InfrastructureMetricSet { addMetric(metrics, ConfigServerMetrics.THROTTLED_HOST_FAILURES.max()); addMetric(metrics, ConfigServerMetrics.THROTTLED_NODE_FAILURES.max()); addMetric(metrics, ConfigServerMetrics.NODE_FAIL_THROTTLING.max()); - addMetric(metrics, ConfigServerMetrics.CLUSTER_AUTOSCALED.max()); + addMetric(metrics, ConfigServerMetrics.CLUSTER_AUTOSCALED.count()); addMetric(metrics, ConfigServerMetrics.ORCHESTRATOR_LOCK_ACQUIRE_SUCCESS.count()); addMetric(metrics, ConfigServerMetrics.ORCHESTRATOR_LOCK_ACQUIRE_TIMEOUT.count()); diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 644b1ec538f..3a64083c623 100644 --- 
a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,12 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; + import java.nio.file.Paths; import java.util.List; import java.util.Map; @@ -31,17 +34,22 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { private final String tokenTypeIdsName; private final String outputName; private final double termScoreThreshold; + private final boolean useCustomReduce; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; @Inject public SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config) { + this(onnx, runtime, config, true); + } + SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config, boolean useCustomReduce) { this.runtime = runtime; inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); tokenTypeIdsName = config.transformerTokenTypeIds(); termScoreThreshold = config.termScoreThreshold(); + this.useCustomReduce = useCustomReduce; var tokenizerPath = Paths.get(config.tokenizerPath().toString()); var builder = new HuggingFaceTokenizer.Builder() @@ -116,20 +124,54 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); - Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType); + IndexedTensor output = (IndexedTensor) evaluator.evaluate(inputs).get(outputName); + Tensor spladeTensor = useCustomReduce + ? sparsifyCustomReduce(output, tensorType) + : sparsifyReduce(output, tensorType); runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); return spladeTensor; } /** - * Sparsify the model output tensor. + * Sparsify the output tensor by applying a threshold on the log of the relu of the output. + * This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant. + * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size + * of the vocabulary + * @param tensorType the type of the destination tensor + * @return A mapped tensor with the terms from the vocab that has a score above the threshold + */ + private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) { + //Remove batch dim, batch size of 1 + Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1"); + Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0))); + IndexedTensor vocab = (IndexedTensor) logOfRelu; + var builder = Tensor.Builder.of(tensorType); + long[] tokens = new long[1]; + for (int i = 0; i < vocab.size(); i++) { + var score = vocab.get(i); + if (score > termScoreThreshold) { + tokens[0] = i; + String term = tokenizer.decode(tokens); + builder.cell(). 
+ label(tensorType.dimensions().get(0).name(), term) + .value(score); + } + } + return builder.build(); + } + + + + /** + * Sparsify the model output tensor. This uses an unrolled custom reduce and is 15-20% faster than using the + * generic tensor reduce. * * @param modelOutput the model output tensor of type tensorType * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ - public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) { + public Tensor sparsifyCustomReduce(IndexedTensor modelOutput, TensorType tensorType) { var builder = Tensor.Builder.of(tensorType); long[] shape = modelOutput.shape(); if(shape.length != 3) { @@ -139,24 +181,38 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { if (batch != 1) { throw new IllegalArgumentException("Batch size must be 1"); } - long sequenceLength = shape[1]; - long vocabSize = shape[2]; + if (shape[1] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int"); + } + if (shape[2] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int"); + } + int sequenceLength = (int) shape[1]; + int vocabSize = (int) shape[2]; + String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token - for(int v = 0; v < vocabSize; v++) { - double maxLogOfRelu = Double.MIN_VALUE; - for(int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab - double logOfRelu = Math.log(1 + Math.max(0, value)); - if(logOfRelu > maxLogOfRelu) { - maxLogOfRelu = logOfRelu; + long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); + for (int v = 0; v < vocabSize; v++) { + double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getDirectIndex(); + for (int s = 0; s < sequenceLength; s++) { + double value = modelOutput.get(directIndex + s * increment); + if (value > maxValue) { + maxValue = value; } } - if (maxLogOfRelu > termScoreThreshold) { - String term = tokenizer.decode(List.of((long) v)); - builder.cell(). - label(tensorType.dimensions().get(0).name(), term) - .value(maxLogOfRelu); + double logOfRelu = Math.log(1 + maxValue); + if (logOfRelu > termScoreThreshold) { + tokens[0] = v; + String term = tokenizer.decode(tokens); + builder.cell() + .label(dimension, term) + .value(logOfRelu); } } return builder.build(); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 2612702e99b..d1a06d8c7ff 100--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -28,6 +28,7 @@ import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @@ -53,10 +54,9 @@ class TensorConverter { static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) throws OrtException { - if ( ! (vespaTensor instanceof IndexedTensor)) { + if ( ! 
(vespaTensor instanceof IndexedTensor tensor)) { throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); } - IndexedTensor tensor = (IndexedTensor) vespaTensor; ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { for (int i = 0; i < tensor.size(); i++) @@ -102,70 +102,67 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } + interface Short2Float { + float convert(short value); + } + + private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, converter.convert(buffer.get(i))); + } + private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + static Tensor toVespaTensor(OnnxValue onnxValue) { - if ( ! (onnxValue instanceof OnnxTensor)) { + if ( ! 
(onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } - OnnxTensor onnxTensor = (OnnxTensor) onnxValue; TensorInfo tensorInfo = onnxTensor.getInfo(); - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - + DimensionSizes sizes = DimensionSizes.of(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.FLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); - } - else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); + long totalSizeAsLong = sizes.totalSize(); + if (totalSizeAsLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE"); + } + + int totalSize = (int) totalSizeAsLong; + switch (tensorInfo.type) { + case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize); + case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize); + case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize); + case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize); + case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize); + case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize); + case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize); + case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize); + default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } return builder.build(); } - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new 
DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), e -> toVespaType(e.getValue().getInfo()))); @@ -201,14 +198,14 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } - return TensorType.Value.DOUBLE; + return switch (onnxType) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE; + default -> TensorType.Value.DOUBLE; + }; } static private TensorInfo toTensorInfo(ValueInfo valueInfo) { diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..e2b1caf4441 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -48,12 +48,12 @@ public class SpladeEmbedderTest { public void testPerformanceNotTerrible() { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; - Long now = System.currentTimeMillis(); - int n = 10; + long now = System.currentTimeMillis(); + int n = 1000; // 7s on Intel core i9 2.4Ghz (macbook pro, 2019) using custom reduce, 8s if using generic reduce for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } @@ -72,9 +72,11 @@ public class SpladeEmbedderTest { static { indexingContext = new Embedder.Context("schema.indexing"); - spladeEmbedder = getEmbedder(); + // Custom reduce is 14% faster than generic reduce and the default. 
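A minimal, self-contained sketch (editorial, not part of this patch) of the term-scoring rule that both sparsifyReduce and sparsifyCustomReduce in SpladeEmbedder above implement: take the maximum activation for each vocabulary entry across all sequence positions, apply log(1 + relu(x)), and keep only terms whose score exceeds the threshold. The class name, the plain float[][] input and the threshold value here are illustrative assumptions; the real code works on Vespa IndexedTensor instances and decodes term strings with the HuggingFace tokenizer.

class SpladeScoringSketch {

    // Score for one vocabulary index: log(1 + relu(max over all sequence positions)).
    static double termScore(float[][] logits, int vocabIndex) {
        double max = 0.0; // relu: negative activations never raise the score
        for (float[] position : logits) {
            if (position[vocabIndex] > max) max = position[vocabIndex];
        }
        return Math.log(1 + max);
    }

    public static void main(String[] args) {
        float[][] logits = { { 0.2f, -1.0f, 3.0f }, { 1.5f, 0.0f, 0.5f } }; // 2 positions, 3 vocab entries
        double termScoreThreshold = 0.5; // plays the role of config.termScoreThreshold()
        for (int v = 0; v < logits[0].length; v++) {
            double score = termScore(logits, v);
            if (score > termScoreThreshold) System.out.println("keep term " + v + " with score " + score);
        }
    }
}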
+ // Keeping as option for performance testing + spladeEmbedder = getEmbedder(false); } - private static Embedder getEmbedder() { + private static Embedder getEmbedder(boolean useCustomReduce) { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer_mlm.onnx"; assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); @@ -83,6 +85,6 @@ public class SpladeEmbedderTest { builder.transformerModel(ModelReference.valueOf(modelPath)); builder.termScoreThreshold(scoreThreshold); builder.transformerGpuDevice(-1); - return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build(), useCustomReduce); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java index 606605ed1e4..4134ea337ab 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java @@ -34,6 +34,7 @@ public class Cluster { private final IntRange groupSize; private final boolean required; private final Autoscaling suggested; + private final List<Autoscaling> suggestions; private final Autoscaling target; private final ClusterInfo clusterInfo; private final BcpGroupInfo bcpGroupInfo; @@ -48,6 +49,7 @@ public class Cluster { IntRange groupSize, boolean required, Autoscaling suggested, + List<Autoscaling> suggestions, Autoscaling target, ClusterInfo clusterInfo, BcpGroupInfo bcpGroupInfo, @@ -59,6 +61,7 @@ public class Cluster { this.groupSize = Objects.requireNonNull(groupSize); this.required = required; this.suggested = Objects.requireNonNull(suggested); + this.suggestions = Objects.requireNonNull(suggestions); Objects.requireNonNull(target); if (target.resources().isPresent() && ! target.resources().get().isWithin(minResources, maxResources)) this.target = target.withResources(Optional.empty()); // Delete illegal target @@ -102,12 +105,21 @@ public class Cluster { */ public Autoscaling suggested() { return suggested; } + /** + * The list of suggested resources, which may or may not be within the min and max limits, + * or empty if there is currently no recorded suggestion. + * List is sorted by preference + */ + public List<Autoscaling> suggestions() { return suggestions; } + /** Returns true if there is a current suggestion and we should actually make this suggestion to users. */ public boolean shouldSuggestResources(ClusterResources currentResources) { - if (suggested.resources().isEmpty()) return false; - if (suggested.resources().get().isWithin(min, max)) return false; - if ( ! Autoscaler.worthRescaling(currentResources, suggested.resources().get())) return false; - return true; + if (suggestions.isEmpty()) return false; + return suggestions.stream().noneMatch(suggestion -> + suggestion.resources().isEmpty() + || suggestion.resources().get().isWithin(min, max) + || ! 
Autoscaler.worthRescaling(currentResources, suggestion.resources().get()) + ); } public ClusterInfo clusterInfo() { return clusterInfo; } @@ -131,19 +143,23 @@ public class Cluster { public Cluster withConfiguration(boolean exclusive, Capacity capacity) { return new Cluster(id, exclusive, capacity.minResources(), capacity.maxResources(), capacity.groupSize(), capacity.isRequired(), - suggested, target, capacity.clusterInfo(), bcpGroupInfo, scalingEvents); + suggested, suggestions, target, capacity.clusterInfo(), bcpGroupInfo, scalingEvents); } public Cluster withSuggested(Autoscaling suggested) { - return new Cluster(id, exclusive, min, max, groupSize, required, suggested, target, clusterInfo, bcpGroupInfo, scalingEvents); + return new Cluster(id, exclusive, min, max, groupSize, required, suggested, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents); + } + + public Cluster withSuggestions(List<Autoscaling> suggestions) { + return new Cluster(id, exclusive, min, max, groupSize, required, suggested, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents); } public Cluster withTarget(Autoscaling target) { - return new Cluster(id, exclusive, min, max, groupSize, required, suggested, target, clusterInfo, bcpGroupInfo, scalingEvents); + return new Cluster(id, exclusive, min, max, groupSize, required, suggested, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents); } public Cluster with(BcpGroupInfo bcpGroupInfo) { - return new Cluster(id, exclusive, min, max, groupSize, required, suggested, target, clusterInfo, bcpGroupInfo, scalingEvents); + return new Cluster(id, exclusive, min, max, groupSize, required, suggested, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents); } /** Add or update (based on "at" time) a scaling event */ @@ -157,7 +173,7 @@ public class Cluster { scalingEvents.add(scalingEvent); prune(scalingEvents); - return new Cluster(id, exclusive, min, max, groupSize, required, suggested, target, clusterInfo, bcpGroupInfo, scalingEvents); + return new Cluster(id, exclusive, min, max, groupSize, required, suggested, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents); } @Override @@ -189,7 +205,7 @@ public class Cluster { public static Cluster create(ClusterSpec.Id id, boolean exclusive, Capacity requested) { return new Cluster(id, exclusive, requested.minResources(), requested.maxResources(), requested.groupSize(), requested.isRequired(), - Autoscaling.empty(), Autoscaling.empty(), requested.clusterInfo(), BcpGroupInfo.empty(), List.of()); + Autoscaling.empty(), List.of(), Autoscaling.empty(), requested.clusterInfo(), BcpGroupInfo.empty(), List.of()); } /** The predicted time it will take to rescale this cluster. 
*/ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java index ff30f9d6163..ae12ca13318 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java @@ -6,7 +6,10 @@ import com.yahoo.config.provision.IntRange; import com.yahoo.config.provision.NodeResources; import com.yahoo.vespa.hosted.provision.NodeRepository; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static com.yahoo.vespa.hosted.provision.autoscale.Autoscaler.headroomRequiredToScaleDown; @@ -37,13 +40,26 @@ public class AllocationOptimizer { public Optional<AllocatableResources> findBestAllocation(Load loadAdjustment, ClusterModel model, Limits limits) { + return findBestAllocations(loadAdjustment, model, limits).stream().findFirst(); + } + + /** + * Searches the space of possible allocations given a target relative load + * and (optionally) cluster limits and returns the best alternative. + * + * @return the best allocations, if there are any possible legal allocations, fulfilling the target + * fully or partially, within the limits. The list contains the three best allocations, sorted from most to least preferred. + */ + public List<AllocatableResources> findBestAllocations(Load loadAdjustment, + ClusterModel model, + Limits limits) { if (limits.isEmpty()) limits = Limits.of(new ClusterResources(minimumNodes, 1, NodeResources.unspecified()), new ClusterResources(maximumNodes, maximumNodes, NodeResources.unspecified()), IntRange.empty()); else limits = atLeast(minimumNodes, limits).fullySpecified(model.current().clusterSpec(), nodeRepository, model.application().id()); - Optional<AllocatableResources> bestAllocation = Optional.empty(); + List<AllocatableResources> bestAllocations = new ArrayList<>(); var availableRealHostResources = nodeRepository.zone().cloud().dynamicProvisioning() ? nodeRepository.flavors().getFlavors().stream().map(flavor -> flavor.resources()).toList() : nodeRepository.nodes().list().hosts().stream().map(host -> host.flavor().resources()) @@ -65,11 +81,20 @@ public class AllocationOptimizer { model, nodeRepository); if (allocatableResources.isEmpty()) continue; - if (bestAllocation.isEmpty() || allocatableResources.get().preferableTo(bestAllocation.get(), model)) - bestAllocation = allocatableResources; + bestAllocations.add(allocatableResources.get()); } } - return bestAllocation; + return bestAllocations.stream() + .sorted((one, other) -> { + if (one.preferableTo(other, model)) + return -1; + else if (other.preferableTo(one, model)) { + return 1; + } + return 0; + }) + .limit(3) + .toList(); } /** Returns the max resources of a host one node may allocate. 
*/ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java index 738abddc31a..40819e709de 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java @@ -9,6 +9,7 @@ import com.yahoo.vespa.hosted.provision.applications.Cluster; import com.yahoo.vespa.hosted.provision.autoscale.Autoscaling.Status; import java.time.Duration; +import java.util.List; /** * The autoscaler gives advice about what resources should be allocated to a cluster based on observed behavior. @@ -39,8 +40,14 @@ public class Autoscaler { * @param clusterNodes the list of all the active nodes in a cluster * @return scaling advice for this cluster */ - public Autoscaling suggest(Application application, Cluster cluster, NodeList clusterNodes) { - return autoscale(application, cluster, clusterNodes, Limits.empty()); + public List<Autoscaling> suggest(Application application, Cluster cluster, NodeList clusterNodes) { + var model = model(application, cluster, clusterNodes); + if (model.isEmpty() || ! model.isStable(nodeRepository)) return List.of(); + + var targets = allocationOptimizer.findBestAllocations(model.loadAdjustment(), model, Limits.empty()); + return targets.stream() + .map(target -> toAutoscaling(target, model)) + .toList(); } /** @@ -50,18 +57,8 @@ public class Autoscaler { * @return scaling advice for this cluster */ public Autoscaling autoscale(Application application, Cluster cluster, NodeList clusterNodes) { - return autoscale(application, cluster, clusterNodes, Limits.of(cluster)); - } - - private Autoscaling autoscale(Application application, Cluster cluster, NodeList clusterNodes, Limits limits) { - var model = new ClusterModel(nodeRepository, - application, - clusterNodes.not().retired().clusterSpec(), - cluster, - clusterNodes, - new AllocatableResources(clusterNodes.not().retired(), nodeRepository), - nodeRepository.metricsDb(), - nodeRepository.clock()); + var limits = Limits.of(cluster); + var model = model(application, cluster, clusterNodes); if (model.isEmpty()) return Autoscaling.empty(); if (! limits.isEmpty() && cluster.minResources().equals(cluster.maxResources())) @@ -78,18 +75,33 @@ public class Autoscaler { if (target.isEmpty()) return Autoscaling.dontScale(Status.insufficient, "No allocations are possible within configured limits", model); - if (target.get().nodes() == 1) + return toAutoscaling(target.get(), model); + } + + private ClusterModel model(Application application, Cluster cluster, NodeList clusterNodes) { + return new ClusterModel(nodeRepository, + application, + clusterNodes.not().retired().clusterSpec(), + cluster, + clusterNodes, + new AllocatableResources(clusterNodes.not().retired(), nodeRepository), + nodeRepository.metricsDb(), + nodeRepository.clock()); + } + + private Autoscaling toAutoscaling(AllocatableResources target, ClusterModel model) { + if (target.nodes() == 1) return Autoscaling.dontScale(Status.unavailable, "Autoscaling is disabled in single node clusters", model); - if (! worthRescaling(model.current().realResources(), target.get().realResources())) { - if (target.get().fulfilment() < 0.9999999) + if (! 
worthRescaling(model.current().realResources(), target.realResources())) { + if (target.fulfilment() < 0.9999999) return Autoscaling.dontScale(Status.insufficient, "Configured limits prevents ideal scaling of this cluster", model); else if ( ! model.safeToScaleDown() && model.idealLoad().any(v -> v < 1.0)) return Autoscaling.dontScale(Status.ideal, "Cooling off before considering to scale down", model); else return Autoscaling.dontScale(Status.ideal, "Cluster is ideally scaled (within configured limits)", model); } - return Autoscaling.scaleTo(target.get().advertisedResources(), model); + return Autoscaling.scaleTo(target.advertisedResources(), model); } /** Returns true if it is worthwhile to make the given resource change, false if it is too insignificant */ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java index 4d0bbb4e511..e9230d2c91a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java @@ -431,7 +431,7 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { try { issueAsync("alter table " + name + " drop partition where at < dateadd('d', -4, now());", newContext()); } - catch (SqlException e) { + catch (Exception e) { if (e.getMessage().contains("no partitions matched WHERE clause")) return; log.log(Level.WARNING, "Failed to gc old metrics data in " + dir + " table " + name, e); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java index e6d0c339e5a..2bec9aa6115 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java @@ -158,7 +158,15 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { log.info("Autoscaling " + application + " " + clusterNodes.clusterSpec() + ":" + "\nfrom " + toString(from) + "\nto " + toString(to)); metric.add(ConfigServerMetrics.CLUSTER_AUTOSCALED.baseName(), 1, - metric.createContext(MetricsReporter.dimensions(application, clusterNodes.clusterSpec().id()))); + metric.createContext(dimensions(application, clusterNodes.clusterSpec()))); + } + + private static Map<String, String> dimensions(ApplicationId application, ClusterSpec clusterSpec) { + return Map.of("tenantName", application.tenant().value(), + "applicationId", application.serializedForm().replace(':', '.'), + "app", application.application().value() + "." 
+ application.instance().value(), + "clusterid", clusterSpec.id().value(), + "clustertype", clusterSpec.type().name()); } static String toString(ClusterResources r) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java index b644e5d8a08..b82d1809085 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java @@ -408,7 +408,7 @@ public class MetricsReporter extends NodeRepositoryMaintainer { metric.set(ConfigServerMetrics.NODES_EMPTY_EXCLUSIVE.baseName(), emptyHosts, null); } - public static Map<String, String> dimensions(ApplicationId application, ClusterSpec.Id cluster) { + static Map<String, String> dimensions(ApplicationId application, ClusterSpec.Id cluster) { Map<String, String> dimensions = new HashMap<>(dimensions(application)); dimensions.put("clusterid", cluster.value()); return dimensions; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java index fd93d202795..fa1be83dbcf 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java @@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.provision.autoscale.Autoscaler; import com.yahoo.vespa.hosted.provision.autoscale.Autoscaling; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -63,13 +64,13 @@ public class ScalingSuggestionsMaintainer extends NodeRepositoryMaintainer { Application application = applications().get(applicationId).orElse(Application.empty(applicationId)); Optional<Cluster> cluster = application.cluster(clusterId); if (cluster.isEmpty()) return true; - var suggestion = autoscaler.suggest(application, cluster.get(), clusterNodes); - if (suggestion.status() == Autoscaling.Status.waiting) return true; - if ( ! shouldUpdateSuggestion(cluster.get().suggested(), suggestion)) return true; + var suggestions = autoscaler.suggest(application, cluster.get(), clusterNodes); + if ( ! 
shouldUpdateSuggestion(cluster.get().suggestions(), suggestions)) + return true; // Wait only a short time for the lock to avoid interfering with change deployments try (Mutex lock = nodeRepository().applications().lock(applicationId, Duration.ofSeconds(1))) { - applications().get(applicationId).ifPresent(a -> updateSuggestion(suggestion, clusterId, a, lock)); + applications().get(applicationId).ifPresent(a -> updateSuggestion(suggestions, clusterId, a, lock)); return true; } catch (ApplicationLockException e) { @@ -77,19 +78,28 @@ public class ScalingSuggestionsMaintainer extends NodeRepositoryMaintainer { } } - private boolean shouldUpdateSuggestion(Autoscaling currentSuggestion, Autoscaling newSuggestion) { - return currentSuggestion.resources().isEmpty() - || currentSuggestion.at().isBefore(nodeRepository().clock().instant().minus(Duration.ofDays(7))) - || (newSuggestion.resources().isPresent() && isHigher(newSuggestion.resources().get(), currentSuggestion.resources().get())); + private boolean shouldUpdateSuggestion(List<Autoscaling> currentSuggestions, List<Autoscaling> newSuggestions) { + // Only compare previous best suggestion with current best suggestion + var currentSuggestion = currentSuggestions.stream().findFirst(); + var newSuggestion = newSuggestions.stream().findFirst(); + + if (currentSuggestion.isEmpty()) return true; + if (newSuggestion.isEmpty()) return false; + + return newSuggestion.get().status() != Autoscaling.Status.waiting + && (currentSuggestion.get().resources().isEmpty() + || currentSuggestion.get().at().isBefore(nodeRepository().clock().instant().minus(Duration.ofDays(7))) + || (newSuggestion.get().resources().isPresent() && isHigher(newSuggestion.get().resources().get(), currentSuggestion.get().resources().get()))); } - private void updateSuggestion(Autoscaling autoscaling, + private void updateSuggestion(List<Autoscaling> suggestions, ClusterSpec.Id clusterId, Application application, Mutex lock) { Optional<Cluster> cluster = application.cluster(clusterId); if (cluster.isEmpty()) return; - applications().put(application.with(cluster.get().withSuggested(autoscaling)), lock); + applications().put(application.with(cluster.get().withSuggestions(suggestions) + .withSuggested(suggestions.stream().findFirst().orElse(Autoscaling.empty()))), lock); } private boolean isHigher(ClusterResources r1, ClusterResources r2) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java index 5d8296d6f9d..bd7f39b1f6e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java @@ -38,7 +38,7 @@ public class DelegatingOsUpgrader extends OsUpgrader { .matching(node -> canUpgradeTo(target.version(), now, node)) .byIncreasingOsVersion() .first(upgradeSlots(target, activeNodes)); - if (nodesToUpgrade.size() == 0) return; + if (nodesToUpgrade.isEmpty()) return; LOG.info("Upgrading " + nodesToUpgrade.size() + " nodes of type " + target.nodeType() + " to OS version " + target.version().toFullString()); nodeRepository.nodes().upgradeOs(NodeListFilter.from(nodesToUpgrade.asList()), Optional.of(target.version())); @@ -49,7 +49,7 @@ public class DelegatingOsUpgrader extends OsUpgrader { NodeList nodesUpgrading = nodeRepository.nodes().list() .nodeType(type) .changingOsVersion(); - if (nodesUpgrading.size() == 0) 
return; + if (nodesUpgrading.isEmpty()) return; LOG.info("Disabling OS upgrade of all " + type + " nodes"); nodeRepository.nodes().upgradeOs(NodeListFilter.from(nodesUpgrading.asList()), Optional.empty()); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java index f56e75518a3..2f09a0a5a29 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java @@ -44,7 +44,7 @@ public abstract class OsUpgrader { /** Returns the number of upgrade slots available for given target */ final int upgradeSlots(OsVersionTarget target, NodeList candidates) { if (!candidates.stream().allMatch(node -> node.type() == target.nodeType())) { - throw new IllegalArgumentException("All node types must type of OS version target " + target.nodeType()); + throw new IllegalArgumentException("All node types must match type of OS version target " + target.nodeType()); } int max = target.nodeType() == NodeType.host ? maxActiveUpgrades.value() : 1; int upgrading = candidates.changingOsVersionTo(target.version()).size(); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java index cb6c7683f23..a5ff7b82551 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java @@ -2,21 +2,32 @@ package com.yahoo.vespa.hosted.provision.os; import com.yahoo.component.Version; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; +import com.yahoo.vespa.hosted.provision.node.ClusterId; import com.yahoo.vespa.hosted.provision.node.filter.NodeListFilter; import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** - * An upgrader that retires and deprovisions hosts on stale OS versions. + * An upgrader that retires and deprovisions hosts on stale OS versions. For hosts containing stateful clusters, this + * upgrader limits node retirement so that at most one group per cluster is affected at a time. * - * Used in clouds where hosts must be re-provisioned to upgrade their OS. + * Used in clouds where the host configuration (e.g. local disk) requires re-provisioning to upgrade OS. * * @author mpolden */ @@ -35,8 +46,8 @@ public class RetiringOsUpgrader extends OsUpgrader { public void upgradeTo(OsVersionTarget target) { NodeList allNodes = nodeRepository.nodes().list(); Instant now = nodeRepository.clock().instant(); - for (var candidate : candidates(now, target, allNodes)) { - deprovision(candidate, target.version(), now); + for (Node host : deprovisionable(now, target, allNodes)) { + deprovision(host, target.version(), now); } } @@ -45,18 +56,46 @@ public class RetiringOsUpgrader extends OsUpgrader { // No action needed in this implementation. 
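A minimal sketch (editorial, not part of this patch) of the per-cluster group gate that the new deprovisionable(...) just below applies before taking another host: a host may be deprovisioned only if, for each stateful cluster with nodes on it, either no group of that cluster is already retiring, or the groups on the host are exactly the groups already retiring in that cluster. Cluster and group names are reduced to plain strings here as an assumption; the real code uses ClusterId and ClusterSpec.Group and also honours the upgrade-slot limit.

import java.util.Map;
import java.util.Set;

class GroupGateSketch {

    // True if taking this host would not spread retirement into additional groups of any cluster.
    static boolean canDeprovision(Map<String, Set<String>> groupsOnHost,
                                  Map<String, Set<String>> retiringGroupsByCluster) {
        for (var clusterAndGroups : groupsOnHost.entrySet()) {
            Set<String> retiring = retiringGroupsByCluster.get(clusterAndGroups.getKey());
            if (retiring != null && !clusterAndGroups.getValue().equals(retiring)) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        Map<String, Set<String>> retiring = Map.of("content1", Set.of("group0"));
        // A host carrying content1/group1 must wait while group0 is still retiring:
        System.out.println(canDeprovision(Map.of("content1", Set.of("group1")), retiring)); // false
        // A host carrying only content1/group0 may join the retirement already in progress:
        System.out.println(canDeprovision(Map.of("content1", Set.of("group0")), retiring)); // true
        // A host whose clusters have nothing retiring yet may also be taken:
        System.out.println(canDeprovision(Map.of("content2", Set.of("group0")), retiring)); // true
    }
}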
} - /** Returns nodes that are candidates for upgrade */ - private NodeList candidates(Instant instant, OsVersionTarget target, NodeList allNodes) { + /** Returns nodes that can be deprovisioned at given instant */ + private List<Node> deprovisionable(Instant instant, OsVersionTarget target, NodeList allNodes) { NodeList nodes = allNodes.state(Node.State.active, Node.State.provisioned).nodeType(target.nodeType()); if (softRebuild) { - // Retire only hosts which do not have a replaceable root disk + // Consider only hosts which do not have a replaceable root disk nodes = nodes.not().replaceableRootDisk(); } - return nodes.not().deprovisioning() - .not().onOsVersion(target.version()) - .matching(node -> canUpgradeTo(target.version(), instant, node)) - .byIncreasingOsVersion() - .first(upgradeSlots(target, nodes.deprovisioning())); + // Retire hosts up to slot limit while ensuring that only one group is retired at a time + NodeList activeNodes = allNodes.state(Node.State.active); + Map<ClusterId, Set<ClusterSpec.Group>> retiringGroupsByCluster = groupsOf(activeNodes.retiring()); + int limit = upgradeSlots(target, nodes.deprovisioning()); + List<Node> result = new ArrayList<>(); + NodeList candidates = nodes.not().deprovisioning() + .not().onOsVersion(target.version()) + .matching(node -> canUpgradeTo(target.version(), instant, node)) + .byIncreasingOsVersion(); + for (Node host : candidates) { + if (result.size() == limit) break; + // For all clusters residing on this host: Determine if deprovisioning the host would imply retiring nodes + // in additional groups beyond those already having retired nodes. If true, defer deprovisioning the host + boolean canDeprovision = true; + Map<ClusterId, Set<ClusterSpec.Group>> groupsOnHost = groupsOf(activeNodes.childrenOf(host)); + for (var clusterAndGroups : groupsOnHost.entrySet()) { + Set<ClusterSpec.Group> groups = clusterAndGroups.getValue(); + Set<ClusterSpec.Group> retiringGroups = retiringGroupsByCluster.get(clusterAndGroups.getKey()); + if (retiringGroups != null && !groups.equals(retiringGroups)) { + canDeprovision = false; + break; + } + } + // Deprovision host and count all cluster groups on the host as being retired + if (canDeprovision) { + result.add(host); + groupsOnHost.forEach((cluster, groups) -> retiringGroupsByCluster.merge(cluster, groups, (oldVal, newVal) -> { + oldVal.addAll(newVal); + return oldVal; + })); + } + } + return Collections.unmodifiableList(result); } /** Upgrade given host by retiring and deprovisioning it */ @@ -68,4 +107,17 @@ public class RetiringOsUpgrader extends OsUpgrader { nodeRepository.nodes().upgradeOs(NodeListFilter.from(host), Optional.of(target)); } + /** Returns the stateful groups present on given nodes, grouped by their cluster ID */ + private static Map<ClusterId, Set<ClusterSpec.Group>> groupsOf(NodeList nodes) { + return nodes.stream() + .filter(node -> node.allocation().isPresent() && + node.allocation().get().membership().cluster().isStateful() && + node.allocation().get().membership().cluster().group().isPresent()) + .collect(Collectors.groupingBy(node -> new ClusterId(node.allocation().get().owner(), + node.allocation().get().membership().cluster().id()), + HashMap::new, + Collectors.mapping(n -> n.allocation().get().membership().cluster().group().get(), + Collectors.toCollection(HashSet::new)))); + } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java index 6f325700401..2dea70825ee 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java @@ -6,6 +6,7 @@ import com.yahoo.config.provision.IntRange; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.ObjectTraverser; @@ -56,6 +57,7 @@ public class ApplicationSerializer { private static final String groupSizeKey = "groupSize"; private static final String requiredKey = "required"; private static final String suggestedKey = "suggested"; + private static final String suggestionsKey = "suggestionsKey"; private static final String clusterInfoKey = "clusterInfo"; private static final String bcpDeadlineKey = "bcpDeadline"; private static final String hostTTLKey = "hostTTL"; @@ -139,7 +141,9 @@ public class ApplicationSerializer { toSlime(cluster.maxResources(), clusterObject.setObject(maxResourcesKey)); toSlime(cluster.groupSize(), clusterObject.setObject(groupSizeKey)); clusterObject.setBool(requiredKey, cluster.required()); + // TODO(olaa): Remove 'suggested' once API clients migrate to suggestion list toSlime(cluster.suggested(), clusterObject.setObject(suggestedKey)); + toSlime(cluster.suggestions(), clusterObject.setArray(suggestionsKey)); toSlime(cluster.target(), clusterObject.setObject(targetKey)); if (! cluster.clusterInfo().isEmpty()) toSlime(cluster.clusterInfo(), clusterObject.setObject(clusterInfoKey)); @@ -156,12 +160,20 @@ public class ApplicationSerializer { intRangeFromSlime(clusterObject.field(groupSizeKey)), clusterObject.field(requiredKey).asBool(), autoscalingFromSlime(clusterObject.field(suggestedKey)), + suggestionsFromSlime(clusterObject.field(suggestionsKey)), autoscalingFromSlime(clusterObject.field(targetKey)), clusterInfoFromSlime(clusterObject.field(clusterInfoKey)), bcpGroupInfoFromSlime(clusterObject.field(bcpGroupInfoKey)), scalingEventsFromSlime(clusterObject.field(scalingEventsKey))); } + private static void toSlime(List<Autoscaling> suggestions, Cursor suggestionsArray) { + suggestions.forEach(suggestion -> { + var suggestionObject = suggestionsArray.addObject(); + toSlime(suggestion, suggestionObject); + }); + } + private static void toSlime(Autoscaling autoscaling, Cursor autoscalingObject) { autoscalingObject.setString(statusKey, toAutoscalingStatusCode(autoscaling.status())); autoscalingObject.setString(descriptionKey, autoscaling.description()); @@ -227,6 +239,13 @@ public class ApplicationSerializer { metricsObject.field(cpuCostPerQueryKey).asDouble()); } + private static List<Autoscaling> suggestionsFromSlime(Inspector suggestionsObject) { + var suggestions = new ArrayList<Autoscaling>(); + if (!suggestionsObject.valid()) return suggestions; + suggestionsObject.traverse((ArrayTraverser) (id, suggestion) -> suggestions.add(autoscalingFromSlime(suggestion))); + return suggestions; + } + private static Autoscaling autoscalingFromSlime(Inspector autoscalingObject) { if ( ! 
autoscalingObject.valid()) return Autoscaling.empty(); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java index 89853896104..0285e72a8a4 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java @@ -1,10 +1,14 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.restapi; -import com.yahoo.config.provision.IntRange; +import com.yahoo.component.Version; +import com.yahoo.component.Vtag; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.ClusterResources; +import com.yahoo.config.provision.IntRange; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; +import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Application; @@ -15,6 +19,7 @@ import com.yahoo.vespa.hosted.provision.autoscale.Limits; import com.yahoo.vespa.hosted.provision.autoscale.Load; import java.net.URI; +import java.util.Comparator; import java.util.List; /** @@ -40,6 +45,13 @@ public class ApplicationSerializer { URI applicationUri) { object.setString("url", applicationUri.toString()); object.setString("id", application.id().toFullString()); + Version version = applicationNodes.stream() + .map(node -> node.status().vespaVersion() + .orElse(node.allocation().get().membership().cluster().vespaVersion())) + .min(Comparator.naturalOrder()) + .orElse(Vtag.currentVersion); + object.setString("version", version.toFullString()); + object.setString("cloudAccount", applicationNodes.stream().findFirst().map(Node::cloudAccount).orElse(CloudAccount.empty).value()); clustersToSlime(application, applicationNodes, nodeRepository, object.setObject("clusters")); } @@ -66,13 +78,23 @@ public class ApplicationSerializer { if ( ! 
cluster.groupSize().isEmpty()) toSlime(cluster.groupSize(), clusterObject.setObject("groupSize")); toSlime(currentResources, clusterObject.setObject("current")); - if (cluster.shouldSuggestResources(currentResources)) + if (cluster.shouldSuggestResources(currentResources)) { toSlime(cluster.suggested(), clusterObject.setObject("suggested")); + toSlime(cluster.suggestions(), clusterObject.setArray("suggestions")); + + } toSlime(cluster.target(), clusterObject.setObject("target")); scalingEventsToSlime(cluster.scalingEvents(), clusterObject.setArray("scalingEvents")); clusterObject.setLong("scalingDuration", cluster.scalingDuration(nodes.clusterSpec()).toMillis()); } + private static void toSlime(List<Autoscaling> suggestions, Cursor autoscalingArray) { + suggestions.forEach(suggestion -> { + var autoscalingObject = autoscalingArray.addObject(); + toSlime(suggestion, autoscalingObject); + }); + } + private static void toSlime(Autoscaling autoscaling, Cursor autoscalingObject) { autoscalingObject.setString("status", autoscaling.status().name()); autoscalingObject.setString("description", autoscaling.description()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java index 93160bf7689..dc70af9a84f 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java @@ -95,7 +95,7 @@ public class NodePatcher { } private void unifiedPatch(String hostname, InputStream json, boolean untrustedTenantHost) { - Inspector root = Exceptions.uncheck(() -> SlimeUtils.jsonToSlime(json.readAllBytes())).get(); + Inspector root = Exceptions.uncheck(() -> SlimeUtils.jsonToSlimeOrThrow(json.readAllBytes())).get(); Map<String, Inspector> fields = new HashMap<>(); root.traverse(fields::put); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java index d3b88997059..e7c9d1079fb 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java @@ -233,6 +233,14 @@ public class MockNodeRepository extends NodeRepository { Load.zero(), Load.zero(), Autoscaling.Metrics.zero())); + cluster1 = cluster1.withSuggestions(List.of(new Autoscaling(Autoscaling.Status.unavailable, + "", + Optional.of(new ClusterResources(6, 2, + new NodeResources(3, 20, 100, 1))), + clock().instant(), + Load.zero(), + Load.zero(), + Autoscaling.Metrics.zero()))); cluster1 = cluster1.withTarget(new Autoscaling(Autoscaling.Status.unavailable, "", Optional.of(new ClusterResources(4, 1, diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java index 4236f7ac968..830ff170a90 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java @@ -462,12 +462,12 @@ public class AutoscalingTest { fixture.tester().clock().advance(Duration.ofDays(2)); fixture.loader().applyLoad(new Load(0.01, 0.01, 0.01, 0, 0), 120); - 
Autoscaling suggestion = fixture.suggest(); + List<Autoscaling> suggestions = fixture.suggest(); fixture.tester().assertResources("Choosing the remote disk flavor as it has less disk", 2, 1, 3.0, 100.0, 10.0, - suggestion); + suggestions); assertEquals("Choosing the remote disk flavor as it has less disk", - StorageType.remote, suggestion.resources().get().nodeResources().storageType()); + StorageType.remote, suggestions.stream().findFirst().flatMap(Autoscaling::resources).get().nodeResources().storageType()); } @Test diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/Fixture.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/Fixture.java index df85ca4865f..4ce909fece3 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/Fixture.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/Fixture.java @@ -108,7 +108,7 @@ public class Fixture { } /** Compute an autoscaling suggestion for this. */ - public Autoscaling suggest() { + public List<Autoscaling> suggest() { return tester().suggest(applicationId, clusterSpec.id(), capacity.minResources(), capacity.maxResources()); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java index f8be27300fe..51297a88cad 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java @@ -78,6 +78,12 @@ public class ScalingSuggestionsMaintainerTest { assertEquals("7 nodes with [vcpu: 4.1, memory: 5.3 Gb, disk: 16.5 Gb, bandwidth: 0.1 Gbps, architecture: any]", suggestionOf(app2, cluster2, tester).resources().get().toString()); + // Secondary suggestions + assertEquals("7 nodes with [vcpu: 3.7, memory: 4.5 Gb, disk: 10.0 Gb, bandwidth: 0.1 Gbps, architecture: any]", + suggestionsOf(app1, cluster1, tester).get(1).resources().get().toString()); + assertEquals("8 nodes with [vcpu: 3.6, memory: 4.7 Gb, disk: 14.2 Gb, bandwidth: 0.1 Gbps, architecture: any]", + suggestionsOf(app2, cluster2, tester).get(1).resources().get().toString()); + // Utilization goes way down tester.clock().advance(Duration.ofHours(13)); addMeasurements(0.10f, 0.10f, 0.10f, 0, 500, app1, tester.nodeRepository()); @@ -97,7 +103,7 @@ public class ScalingSuggestionsMaintainerTest { tester.clock().advance(Duration.ofDays(3)); addMeasurements(0.7f, 0.7f, 0.7f, 0, 500, app1, tester.nodeRepository()); maintainer.maintain(); - var suggested = tester.nodeRepository().applications().get(app1).get().cluster(cluster1.id()).get().suggested().resources().get(); + var suggested = tester.nodeRepository().applications().get(app1).get().cluster(cluster1.id()).get().suggestions().stream().findFirst().flatMap(Autoscaling::resources).get(); tester.deploy(app1, cluster1, Capacity.from(suggested, suggested, IntRange.empty(), false, true, Optional.empty(), ClusterInfo.empty())); tester.clock().advance(Duration.ofDays(2)); @@ -121,7 +127,11 @@ public class ScalingSuggestionsMaintainerTest { } private Autoscaling suggestionOf(ApplicationId app, ClusterSpec cluster, ProvisioningTester tester) { - return tester.nodeRepository().applications().get(app).get().cluster(cluster.id()).get().suggested(); + return suggestionsOf(app, cluster, 
tester).get(0); + } + + private List<Autoscaling> suggestionsOf(ApplicationId app, ClusterSpec cluster, ProvisioningTester tester) { + return tester.nodeRepository().applications().get(app).get().cluster(cluster.id()).get().suggestions(); } private boolean shouldSuggest(ApplicationId app, ClusterSpec cluster, ProvisioningTester tester) { diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java index dcbac44a37f..3f2d7112224 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java @@ -253,6 +253,63 @@ public class OsVersionsTest { } @Test + public void upgrade_by_retiring_is_limited_by_group_membership() { + var versions = new OsVersions(tester.nodeRepository(), Cloud.builder().dynamicProvisioning(true).build(), + Optional.ofNullable(tester.hostProvisioner())); + int hostCount = 7; + int app1GroupCount = 2; + setMaxActiveUpgrades(hostCount); + ApplicationId app1 = ApplicationId.from("t1", "a1", "i1"); + ApplicationId app2 = ApplicationId.from("t2", "a2", "i2"); + provisionInfraApplication(hostCount, NodeType.host); + deployApplication(app1, app1GroupCount); + deployApplication(app2); + Supplier<NodeList> hosts = () -> tester.nodeRepository().nodes().list() + .nodeType(NodeType.host) + .not().state(Node.State.deprovisioned); + + // All hosts are on initial version + var version0 = Version.fromString("8.0"); + versions.setTarget(NodeType.host, version0, false); + setCurrentVersion(hosts.get().asList(), version0); + + // New version is triggered + var version1 = Version.fromString("8.5"); + versions.setTarget(NodeType.host, version1, false); + versions.resumeUpgradeOf(NodeType.host, true); + { + // At most one node per group is retired + NodeList allNodes = tester.nodeRepository().nodes().list().not().state(Node.State.deprovisioned); + assertEquals(hostCount - 1, allNodes.nodeType(NodeType.host).deprovisioning().size()); + assertEquals(1, allNodes.owner(app1).retiring().group(0).size()); + assertEquals(0, allNodes.owner(app1).retiring().group(1).size()); + assertEquals(2, allNodes.owner(app2).retiring().size()); + + // Hosts complete reprovisioning + NodeList emptyHosts = allNodes.deprovisioning().nodeType(NodeType.host) + .matching(h -> allNodes.childrenOf(h).isEmpty()); + completeReprovisionOf(emptyHosts.asList(), NodeType.host); + replaceNodes(app1, app1GroupCount); + replaceNodes(app2); + completeReprovisionOf(hosts.get().deprovisioning().asList(), NodeType.host); + } + { + // Last host/group is retired + versions.resumeUpgradeOf(NodeType.host, true); + NodeList allNodes = tester.nodeRepository().nodes().list().not().state(Node.State.deprovisioned); + assertEquals(1, allNodes.nodeType(NodeType.host).deprovisioning().size()); + assertEquals(0, allNodes.owner(app1).retiring().group(0).size()); + assertEquals(1, allNodes.owner(app1).retiring().group(1).size()); + assertEquals(0, allNodes.owner(app2).retiring().size()); + replaceNodes(app1, app1GroupCount); + completeReprovisionOf(hosts.get().deprovisioning().asList(), NodeType.host); + } + NodeList allHosts = hosts.get(); + assertEquals(0, allHosts.deprovisioning().size()); + assertEquals(allHosts.size(), allHosts.onOsVersion(version1).size()); + } + + @Test public void upgrade_by_rebuilding() { var versions = new OsVersions(tester.nodeRepository(), Cloud.defaultCloud(), 
Optional.ofNullable(tester.hostProvisioner())); setMaxActiveUpgrades(1); @@ -547,24 +604,32 @@ public class OsVersionsTest { } private void deployApplication(ApplicationId application) { + deployApplication(application, 1); + } + + private void deployApplication(ApplicationId application, int groups) { ClusterSpec contentSpec = ClusterSpec.request(ClusterSpec.Type.content, ClusterSpec.Id.from("content1")).vespaVersion("7").build(); - List<HostSpec> hostSpecs = tester.prepare(application, contentSpec, 2, 1, new NodeResources(4, 8, 100, 0.3)); + List<HostSpec> hostSpecs = tester.prepare(application, contentSpec, 2, groups, new NodeResources(4, 8, 100, 0.3)); tester.activate(application, hostSpecs); } - private void replaceNodes(ApplicationId application) { + private void replaceNodes(ApplicationId application, int groups) { // Deploy to retire nodes - deployApplication(application); + deployApplication(application, groups); NodeList retired = tester.nodeRepository().nodes().list().owner(application).retired(); assertFalse("At least one node is retired", retired.isEmpty()); tester.nodeRepository().nodes().setRemovable(retired, false); // Redeploy to deactivate removable nodes and allocate new ones - deployApplication(application); + deployApplication(application, groups); tester.nodeRepository().nodes().list(Node.State.inactive).owner(application) .forEach(node -> tester.nodeRepository().nodes().removeRecursively(node, true)); } + private void replaceNodes(ApplicationId application) { + replaceNodes(application, 1); + } + private NodeList deprovisioningChildrenOf(Node parent) { return tester.nodeRepository().nodes().list() .childrenOf(parent) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java index 918a9043c93..90af6dca090 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java @@ -41,6 +41,7 @@ public class ApplicationSerializerTest { IntRange.empty(), true, Autoscaling.empty(), + List.of(), Autoscaling.empty(), ClusterInfo.empty(), BcpGroupInfo.empty(), @@ -60,6 +61,14 @@ public class ApplicationSerializerTest { new Load(0.1, 0.2, 0.3, 0.4, 0.5), new Load(0.4, 0.5, 0.6, 0.7, 0.8), new Autoscaling.Metrics(0.7, 0.8, 0.9)), + List.of(new Autoscaling(Autoscaling.Status.unavailable, + "", + Optional.of(new ClusterResources(20, 10, + new NodeResources(0.5, 4, 14, 16))), + Instant.ofEpochMilli(1234L), + new Load(0.1, 0.2, 0.3, 0.4, 0.5), + new Load(0.4, 0.5, 0.6, 0.7, 0.8), + new Autoscaling.Metrics(0.7, 0.8, 0.9))), new Autoscaling(Autoscaling.Status.insufficient, "Autoscaling status", Optional.of(new ClusterResources(10, 5, @@ -98,6 +107,7 @@ public class ApplicationSerializerTest { assertEquals(originalCluster.groupSize(), serializedCluster.groupSize()); assertEquals(originalCluster.required(), serializedCluster.required()); assertEquals(originalCluster.suggested(), serializedCluster.suggested()); + assertEquals(originalCluster.suggestions(), serializedCluster.suggestions()); assertEquals(originalCluster.target(), serializedCluster.target()); assertEquals(originalCluster.clusterInfo(), serializedCluster.clusterInfo()); assertEquals(originalCluster.bcpGroupInfo(), serializedCluster.bcpGroupInfo()); diff --git 
a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java index be2b2ca896a..6b6ef49fa5d 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Optional; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** @@ -143,6 +144,7 @@ public class DynamicProvisioningTester { cluster.groupSize(), cluster.required(), cluster.suggested(), + cluster.suggestions(), cluster.target(), cluster.clusterInfo(), cluster.bcpGroupInfo(), @@ -165,7 +167,7 @@ public class DynamicProvisioningTester { nodeRepository().nodes().list(Node.State.active).owner(applicationId)); } - public Autoscaling suggest(ApplicationId applicationId, ClusterSpec.Id clusterId, + public List<Autoscaling> suggest(ApplicationId applicationId, ClusterSpec.Id clusterId, ClusterResources min, ClusterResources max) { Application application = nodeRepository().applications().get(applicationId).orElse(Application.empty(applicationId)) .withCluster(clusterId, false, Capacity.from(min, max)); @@ -199,6 +201,14 @@ public class DynamicProvisioningTester { public ClusterResources assertResources(String message, int nodeCount, int groupCount, double approxCpu, double approxMemory, double approxDisk, + List<Autoscaling> autoscaling) { + assertFalse(autoscaling.isEmpty()); + return assertResources(message, nodeCount, groupCount, approxCpu, approxMemory, approxDisk, autoscaling.get(0)); + } + + public ClusterResources assertResources(String message, + int nodeCount, int groupCount, + double approxCpu, double approxMemory, double approxDisk, Autoscaling autoscaling) { assertTrue("Resources should be present: " + message + " (" + autoscaling + ": " + autoscaling.status() + ")", autoscaling.resources().isPresent()); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json index 7b2cf1dc8e4..e74e705e1aa 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json @@ -1,6 +1,8 @@ { "url" : "http://localhost:8080/nodes/v2/applications/tenant1.application1.instance1", "id" : "tenant1.application1.instance1", + "cloudAccount": "aws:111222333444", + "version": "5.104.142", "clusters" : { "id1" : { "type": "container", @@ -80,6 +82,45 @@ "cpuCostPerQuery" : 0.0 } }, + "suggestions": [ + { + "at": 123, + "description": "", + "ideal": { + "cpu": 0.0, + "disk": 0.0, + "gpu": 0.0, + "gpuMemory": 0.0, + "memory": 0.0 + }, + "metrics": { + "cpuCostPerQuery": 0.0, + "growthRateHeadroom": 0.0, + "queryRate": 0.0 + }, + "peak": { + "cpu": 0.0, + "disk": 0.0, + "gpu": 0.0, + "gpuMemory": 0.0, + "memory": 0.0 + }, + "resources": { + "groups": 2, + "nodes": 6, + "resources": { + "architecture": "any", + "bandwidthGbps": 1.0, + "diskGb": 100.0, + "diskSpeed": "fast", + "memoryGb": 20.0, + "storageType": "any", + "vcpu": 3.0 + } + }, + "status": "unavailable" + } + ], "target" : { "status" : 
"unavailable", "description" : "", diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json index 10173089f75..abd76f9be96 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json @@ -1,6 +1,8 @@ { "url": "http://localhost:8080/nodes/v2/applications/tenant2.application2.instance2", "id": "tenant2.application2.instance2", + "cloudAccount": "aws:111222333444", + "version": "6.42.0", "clusters": { "id2": { "type": "content", diff --git a/parent/pom.xml b/parent/pom.xml index 82bf09ea92b..c10c74c0cb1 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -317,7 +317,7 @@ --> <groupId>org.openrewrite.maven</groupId> <artifactId>rewrite-maven-plugin</artifactId> - <version>5.20.0</version> + <version>5.21.0</version> <configuration> <activeRecipes> <recipe>org.openrewrite.java.testing.junit5.JUnit5BestPractices</recipe> @@ -327,7 +327,7 @@ <dependency> <groupId>org.openrewrite.recipe</groupId> <artifactId>rewrite-testing-frameworks</artifactId> - <version>2.3.0</version> + <version>2.3.1</version> </dependency> </dependencies> </plugin> @@ -1172,7 +1172,7 @@ See pluginManagement of rewrite-maven-plugin for more details --> <groupId>org.openrewrite.recipe</groupId> <artifactId>rewrite-recipe-bom</artifactId> - <version>2.6.2</version> + <version>2.6.3</version> <type>pom</type> <scope>import</scope> </dependency> diff --git a/screwdriver.yaml b/screwdriver.yaml index a8aef08d557..1a3bada2c42 100644 --- a/screwdriver.yaml +++ b/screwdriver.yaml @@ -537,34 +537,6 @@ jobs: - cleanup: | screwdriver/delete-old-cloudsmith-artifacts.sh - mirror-copr-rpms-to-artifactory: - image: docker.io/almalinux:8 - annotations: - screwdriver.cd/cpu: LOW - screwdriver.cd/ram: LOW - screwdriver.cd/disk: HIGH - screwdriver.cd/timeout: 60 - screwdriver.cd/buildPeriodically: H 6 * * * - secrets: - - JFROG_API_TOKEN - steps: - - install: | - dnf install -y dnf-plugins-core - - mirror: | - screwdriver/publish-unpublished-rpms-to-jfrog-cloud.sh - - delete-old-versions-on-artifactory: - annotations: - screwdriver.cd/cpu: LOW - screwdriver.cd/ram: LOW - screwdriver.cd/timeout: 10 - screwdriver.cd/buildPeriodically: H 6 * * 1 - secrets: - - JFROG_API_TOKEN - steps: - - cleanup: | - screwdriver/delete-old-artifactory-artifacts.sh - link-check: image: ruby:3.1 annotations: @@ -582,7 +554,7 @@ jobs: bundle exec jekyll build - check-links: | bundle exec htmlproofer \ - --assume-extension --check-html --check-external-hash --no-enforce-http \ + --assume-extension --check-html --no-check-external-hash --no-enforce-http \ --typhoeus '{"connecttimeout": 10, "timeout": 30, "followlocation": false}' \ --hydra '{"max_concurrency": 1}' \ --ignore-urls '/slack.vespa.ai/,/localhost:8080/,/127.0.0.1:3000/,/favicon.svg/,/main.jsx/' \ diff --git a/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp b/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp index e136e491f05..4aefa10f5f2 100644 --- a/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp +++ b/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp @@ -180,9 +180,10 @@ TEST_F(LidAllocatorTest, whitelist_blueprint_can_maximize_relative_estimate) activate_lids({ 
1, 2, 3, 4 }, true); // the number of hits are overestimated based on the number of // documents that could be active (100 in this test fixture) - EXPECT_EQ(make_whitelist_blueprint(1000)->estimate(), 0.1); - EXPECT_EQ(make_whitelist_blueprint(200)->estimate(), 0.5); - EXPECT_EQ(make_whitelist_blueprint(5)->estimate(), 1.0); + // NOTE: optimize must be called in order to calculate the relative estimate + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(1000))->estimate(), 0.1); + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(200))->estimate(), 0.5); + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(5))->estimate(), 1.0); } class LidAllocatorPerformanceTest : public LidAllocatorTest, diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp index 159ae339aa9..d2e649bb339 100644 --- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp +++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp @@ -8,6 +8,7 @@ #include <string> #include <vector> #include <map> +#include <set> #include <initializer_list> const char *prog = "../../../apps/verify_ranksetup/vespa-verify-ranksetup-bin"; @@ -25,6 +26,8 @@ using search::index::schema::DataType; using vespalib::make_string_short::fmt; +enum class SearchMode { INDEXED, STREAMING, BOTH }; + struct Writer { FILE *file; explicit Writer(const std::string &file_name) { @@ -145,6 +148,39 @@ struct Setup { out.fmt("indexfield[%zu].collectiontype %s\n", i, pos->second.second.c_str()); } } + void write_vsmfield(const Writer &out, size_t idx, std::string name, std::string dataType) { + out.fmt("fieldspec[%zu].name \"%s\"\n", idx, name.c_str()); + if (dataType == "STRING") { + out.fmt("fieldspec[%zu].searchmethod AUTOUTF8\n", idx); + out.fmt("fieldspec[%zu].normalize LOWERCASE\n", idx); + } else { + out.fmt("fieldspec[%zu].searchmethod %s\n", idx, dataType.c_str()); + } + } + void write_vsmfields(const Writer &out) { + std::set<std::string> allFields; + size_t i = 0; + for (const auto & field : indexes) { + write_vsmfield(out, i, field.first, field.second.first); + out.fmt("fieldspec[%zu].fieldtype INDEX\n", i); + i++; + allFields.insert(field.first); + } + for (const auto & field : attributes) { + if (allFields.count(field.first) != 0) continue; + write_vsmfield(out, i, field.first, field.second.dataType); + out.fmt("fieldspec[%zu].fieldtype ATTRIBUTE\n", i); + i++; + allFields.insert(field.first); + } + out.fmt("documenttype[0].name \"foobar\"\n"); + size_t j = 0; + for (const auto & field : allFields) { + out.fmt("documenttype[0].index[%zu].name \"%s\"\n", j, field.c_str()); + out.fmt("documenttype[0].index[%zu].field[0].name \"%s\"\n", j, field.c_str()); + j++; + } + } void write_rank_profiles(const Writer &out) { out.fmt("rankprofile[%zu]\n", extra_profiles.size() + 1); out.fmt("rankprofile[0].name \"default\"\n"); @@ -165,7 +201,7 @@ struct Setup { for (const auto &entry: constants) { out.fmt("constant[%zu].name \"%s\"\n", idx, entry.first.c_str()); out.fmt("constant[%zu].fileref \"12345\"\n", idx); - out.fmt("constant[%zu].type \"%s\"\n", idx, entry.second.c_str()); + out.fmt("constant[%zu].type \"%s\"\n", idx, entry.second.c_str()); ++idx; } } @@ -215,32 +251,45 @@ struct Setup { void generate() { write_attributes(Writer(gen_dir + "/attributes.cfg")); write_indexschema(Writer(gen_dir + "/indexschema.cfg")); + write_vsmfields(Writer(gen_dir + "/vsmfields.cfg")); write_rank_profiles(Writer(gen_dir + 
"/rank-profiles.cfg")); write_ranking_constants(Writer(gen_dir + "/ranking-constants.cfg")); write_ranking_expressions(Writer(gen_dir + "/ranking-expressions.cfg")); write_onnx_models(Writer(gen_dir + "/onnx-models.cfg")); write_self_cfg(Writer(gen_dir + "/verify-ranksetup.cfg")); } - bool verify() { + bool verify(SearchMode mode = SearchMode::BOTH) { + if (mode == SearchMode::BOTH) { + bool res_indexed = verify_mode(SearchMode::INDEXED); + bool res_streaming = verify_mode(SearchMode::STREAMING); + EXPECT_EQUAL(res_indexed, res_streaming); + return res_indexed; + } else { + return verify_mode(mode); + } + } + bool verify_mode(SearchMode mode) { generate(); - vespalib::Process process(fmt("%s dir:%s", prog, gen_dir.c_str()), true); + vespalib::Process process(fmt("%s dir:%s%s", prog, gen_dir.c_str(), + (mode == SearchMode::STREAMING ? " -S" : "")), + true); for (auto line = process.read_line(); !line.empty(); line = process.read_line()) { fprintf(stderr, "> %s\n", line.c_str()); } return (process.join() == 0); } - void verify_valid(std::initializer_list<std::string> features) { + void verify_valid(std::initializer_list<std::string> features, SearchMode mode = SearchMode::BOTH) { for (const std::string &f: features) { first_phase(f); - if (!EXPECT_TRUE(verify())) { + if (!EXPECT_TRUE(verify(mode))) { fprintf(stderr, "--> feature '%s' was invalid (should be valid)\n", f.c_str()); } } } - void verify_invalid(std::initializer_list<std::string> features) { + void verify_invalid(std::initializer_list<std::string> features, SearchMode mode = SearchMode::BOTH) { for (const std::string &f: features) { first_phase(f); - if (!EXPECT_TRUE(!verify())) { + if (!EXPECT_TRUE(!verify(mode))) { fprintf(stderr, "--> feature '%s' was valid (should be invalid)\n", f.c_str()); } } @@ -346,12 +395,12 @@ TEST_F("require that dump features can break validation", SimpleSetup()) { //----------------------------------------------------------------------------- TEST_F("require that fieldMatch feature requires single value field", SimpleSetup()) { - f.verify_invalid({"fieldMatch(keywords)", "fieldMatch(list)"}); + f.verify_invalid({"fieldMatch(keywords)", "fieldMatch(list)"}, SearchMode::INDEXED); f.verify_valid({"fieldMatch(title)"}); } TEST_F("require that age feature requires attribute parameter", SimpleSetup()) { - f.verify_invalid({"age(unknown)", "age(title)"}); + f.verify_invalid({"age(unknown)", "age(title)"}, SearchMode::INDEXED); f.verify_valid({"age(date)"}); } @@ -361,7 +410,7 @@ TEST_F("require that nativeRank can be used on any valid field", SimpleSetup()) } TEST_F("require that nativeAttributeMatch requires attribute parameter", SimpleSetup()) { - f.verify_invalid({"nativeAttributeMatch(unknown)", "nativeAttributeMatch(title)", "nativeAttributeMatch(title,date)"}); + f.verify_invalid({"nativeAttributeMatch(unknown)", "nativeAttributeMatch(title)", "nativeAttributeMatch(title,date)"}, SearchMode::INDEXED); f.verify_valid({"nativeAttributeMatch", "nativeAttributeMatch(date)"}); } diff --git a/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp b/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp index e3315399ed9..ede5728026e 100644 --- a/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp +++ b/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp @@ -34,7 +34,6 @@ #include <vespa/storage/config/config-stor-bouncer.h> #include <vespa/storage/config/config-stor-communicationmanager.h> #include <vespa/storage/config/config-stor-distributormanager.h> -#include 
<vespa/storage/config/config-stor-opslogger.h> #include <vespa/storage/config/config-stor-prioritymapping.h> #include <vespa/storage/config/config-stor-server.h> #include <vespa/storage/config/config-stor-status.h> @@ -103,7 +102,6 @@ using vespa::config::content::core::BucketspacesConfigBuilder; using vespa::config::content::core::StorBouncerConfigBuilder; using vespa::config::content::core::StorCommunicationmanagerConfigBuilder; using vespa::config::content::core::StorDistributormanagerConfigBuilder; -using vespa::config::content::core::StorOpsloggerConfigBuilder; using vespa::config::content::core::StorPrioritymappingConfigBuilder; using vespa::config::content::core::StorServerConfigBuilder; using vespa::config::content::core::StorStatusConfigBuilder; @@ -275,7 +273,6 @@ struct StorageConfigSet StorDistributionConfigBuilder stor_distribution; StorBouncerConfigBuilder stor_bouncer; StorCommunicationmanagerConfigBuilder stor_communicationmanager; - StorOpsloggerConfigBuilder stor_opslogger; StorPrioritymappingConfigBuilder stor_prioritymapping; UpgradingConfigBuilder upgrading; StorServerConfigBuilder stor_server; @@ -292,7 +289,6 @@ struct StorageConfigSet stor_distribution(), stor_bouncer(), stor_communicationmanager(), - stor_opslogger(), stor_prioritymapping(), upgrading(), stor_server(), @@ -335,7 +331,6 @@ struct StorageConfigSet set.addBuilder(config_id, &stor_distribution); set.addBuilder(config_id, &stor_bouncer); set.addBuilder(config_id, &stor_communicationmanager); - set.addBuilder(config_id, &stor_opslogger); set.addBuilder(config_id, &stor_prioritymapping); set.addBuilder(config_id, &upgrading); set.addBuilder(config_id, &stor_server); diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp index f81b47583b9..30ba7d320f7 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp @@ -627,7 +627,7 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI bool allAttributes, OnWriteDoneType onWriteDone) { for (const auto &wc : _writeContexts) { - if (wc.use_two_phase_put()) { + if (allAttributes && wc.use_two_phase_put()) { assert(wc.getFields().size() == 1); wc.consider_build_field_paths(doc); auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc, doc); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp index 784ce649c5f..98c5daa1415 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp @@ -102,9 +102,8 @@ do_limit(AttributeLimiter &limiter_factory, SearchIterator::UP search, double ma return search; } -// When hitrate is below 1% limiting the query is often far more expensive than not. -// TODO This limit should probably be a lot higher. -constexpr double MIN_HIT_RATE_LIMIT = 0.01; +// When hitrate is below 0.2% limiting the query is often far more expensive than not. 
+constexpr double MIN_HIT_RATE_LIMIT = 0.002; } // namespace proton::matching::<unnamed> diff --git a/searchcore/src/vespa/searchcore/proton/matching/query.cpp b/searchcore/src/vespa/searchcore/proton/matching/query.cpp index a93e8fbbddc..5ade0a44b8a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/query.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/query.cpp @@ -200,8 +200,7 @@ Query::reserveHandles(const IRequestContext & requestContext, ISearchContext &co void Query::optimize(bool sort_by_cost) { - (void) sort_by_cost; - _blueprint = Blueprint::optimize(std::move(_blueprint), sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); LOG(debug, "optimized blueprint:\n%s\n", _blueprint->asString().c_str()); } @@ -223,7 +222,7 @@ Query::handle_global_filter(const IRequestContext & requestContext, uint32_t doc } // optimized order may change after accounting for global filter: trace.addEvent(5, "Optimize query execution plan to account for global filter"); - _blueprint = Blueprint::optimize(std::move(_blueprint), sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); LOG(debug, "blueprint after handle_global_filter:\n%s\n", _blueprint->asString().c_str()); // strictness may change if optimized order changed: fetchPostings(ExecuteInfo::create(true, 1.0, requestContext.getDoom(), requestContext.thread_bundle())); diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 5628db99171..0e97621f228 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -190,6 +190,7 @@ vespa_define_module( src/tests/postinglistbm src/tests/predicate src/tests/query + src/tests/query/streaming src/tests/queryeval src/tests/queryeval/blueprint src/tests/queryeval/dot_product @@ -203,6 +204,7 @@ vespa_define_module( src/tests/queryeval/monitoring_search_iterator src/tests/queryeval/multibitvectoriterator src/tests/queryeval/nearest_neighbor + src/tests/queryeval/or_speed src/tests/queryeval/parallel_weak_and src/tests/queryeval/predicate src/tests/queryeval/profiled_iterator diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java index 467a7860053..ed672c2dcd7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java @@ -11,11 +11,9 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.List; -import java.util.Optional; import java.util.Objects; /** @@ -26,7 +24,7 @@ import java.util.Objects; @Beta public class UnpackBitsNode extends CompositeNode { - private static String operationName = "unpack_bits"; + private static final String operationName = "unpack_bits"; private enum EndianNess { BIG_ENDIAN("big"), LITTLE_ENDIAN("little"); @@ -121,9 +119,9 @@ public class UnpackBitsNode extends CompositeNode { var dim = inputType.dimensions().get(i); if (dim.name().equals(meta.unpackDimension())) { long newIdx = oldAddr.numericLabel(i) * 8 + bitIdx; - addrBuilder.add(dim.name(), String.valueOf(newIdx)); + addrBuilder.add(dim.name(), newIdx); } else { - addrBuilder.add(dim.name(), oldAddr.label(i)); + addrBuilder.add(dim.name(), (int) 
oldAddr.numericLabel(i)); } } var newAddr = addrBuilder.build(); @@ -152,7 +150,6 @@ public class UnpackBitsNode extends CompositeNode { if (lastDim.size().isEmpty()) { throw new IllegalArgumentException("bad " + operationName + "; last indexed dimension must be bound, but type was: " + inputType); } - List<TensorType.Dimension> outputDims = new ArrayList<>(); var ttBuilder = new TensorType.Builder(targetCellType); for (var dim : inputType.dimensions()) { if (dim.name().equals(lastDim.name())) { diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 591f0eb8b37..97aa42f79c9 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -177,7 +177,7 @@ List<ReferenceNode> featureList() : ReferenceNode exp; } { - ( ( exp = feature() { ret.add(exp); } )+ <EOF> ) + ( ( exp = feature() { ret.add(exp); } )* <EOF> ) { return ret; } } diff --git a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp index f612bdda87f..6e479e1d9db 100644 --- a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp +++ b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp @@ -430,7 +430,8 @@ BitVectorTest::test(BasicType bt, CollectionType ct, const vespalib::string &pre sc = getSearch<VectorType>(tv, filter); checkSearch(v, std::move(sc), 2, 1022, 205, !filter, true); const auto* dww = v->as_docid_with_weight_posting_store(); - if (dww != nullptr) { + if ((dww != nullptr) && (bt == BasicType::STRING)) { + // This way of doing lookup is only supported by string attributes. auto lres = dww->lookup(getSearchStr<VectorType>(), dww->get_dictionary_snapshot()); using DWSI = search::queryeval::DocidWithWeightSearchIterator; TermFieldMatchData md; diff --git a/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp b/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp index 433ad9e7671..8ba8c62c5ff 100644 --- a/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp +++ b/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp @@ -197,7 +197,7 @@ dfa_fuzzy_match_in_dictionary_no_skip(std::string_view target, const StringEnumS size_t seeks = 0; for (;itr.valid(); ++itr) { auto word = store.get_value(itr.getKey().load_relaxed()); - if (matcher.is_match(word)) { + if (matcher.is_match(std::string_view(word))) { ++matches; if (collect_matches) { matched_words.push_back(word); diff --git a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp index 899ddaa3cc0..87b771af8e6 100644 --- a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp @@ -26,6 +26,7 @@ using LookupKey = IDirectPostingStore::LookupKey; struct IntegerKey : public LookupKey { int64_t _value; IntegerKey(int64_t value_in) : _value(value_in) {} + IntegerKey(const vespalib::string&) : _value() { abort(); } vespalib::stringref asString() const override { abort(); } bool asInteger(int64_t& value) const override { value = _value; return true; } }; @@ -33,6 +34,7 @@ struct IntegerKey : public LookupKey { struct StringKey : public LookupKey { vespalib::string _value; StringKey(int64_t value_in) : 
_value(std::to_string(value_in)) {} + StringKey(const vespalib::string& value_in) : _value(value_in) {} vespalib::stringref asString() const override { return _value; } bool asInteger(int64_t&) const override { abort(); } }; @@ -78,6 +80,10 @@ populate_attribute(AttributeType& attr, const std::vector<DataType>& values) for (auto docid : range(300, 128)) { attr.update(docid, values[3]); } + if (values.size() > 4) { + attr.update(40, values[4]); + attr.update(41, values[5]); + } attr.commit(true); } @@ -93,7 +99,7 @@ make_attribute(CollectionType col_type, BasicType type, bool field_is_filter) auto attr = test::AttributeBuilder(field_name, cfg).docs(num_docs).get(); if (type == BasicType::STRING) { populate_attribute<StringAttribute, vespalib::string>(dynamic_cast<StringAttribute&>(*attr), - {"1", "3", "100", "300"}); + {"1", "3", "100", "300", "foo", "Foo"}); } else { populate_attribute<IntegerAttribute, int64_t>(dynamic_cast<IntegerAttribute&>(*attr), {1, 3, 100, 300}); @@ -156,12 +162,17 @@ using MultiInBlueprintType = DirectMultiTermBlueprint<IDocidWithWeightPostingSto using SingleWSetBlueprintType = DirectMultiTermBlueprint<IDocidPostingStore, WeightedSetTermSearch>; using MultiWSetBlueprintType = DirectMultiTermBlueprint<IDocidWithWeightPostingStore, WeightedSetTermSearch>; +vespalib::string iterator_unpack_docid_and_weights = "search::queryeval::WeightedSetTermSearchImpl<(search::queryeval::UnpackType)0"; +vespalib::string iterator_unpack_docid = "search::queryeval::WeightedSetTermSearchImpl<(search::queryeval::UnpackType)1"; +vespalib::string iterator_unpack_none = "search::queryeval::WeightedSetTermSearchImpl<(search::queryeval::UnpackType)2"; + class DirectMultiTermBlueprintTest : public ::testing::TestWithParam<TestParam> { public: std::shared_ptr<AttributeVector> attr; bool in_operator; bool single_type; bool integer_type; + bool field_is_filter; std::shared_ptr<ComplexLeafBlueprint> blueprint; Blueprint::HitEstimate estimate; fef::TermFieldMatchData tfmd; @@ -171,6 +182,7 @@ public: in_operator(true), single_type(true), integer_type(true), + field_is_filter(false), blueprint(), tfmd(), tfmda() @@ -178,7 +190,8 @@ public: tfmda.add(&tfmd); } ~DirectMultiTermBlueprintTest() {} - void setup(bool field_is_filter, bool need_term_field_match_data) { + void setup(bool field_is_filter_in, bool need_term_field_match_data) { + field_is_filter = field_is_filter_in; attr = make_attribute(GetParam().col_type, GetParam().type, field_is_filter); in_operator = GetParam().op_type == OperatorType::In; single_type = GetParam().col_type == CollectionType::SINGLE; @@ -216,15 +229,16 @@ public: tfmd.tagAsNotNeeded(); } } - template <typename BlueprintType> - void add_term_helper(BlueprintType& b, int64_t term_value) { + template <typename BlueprintType, typename TermType> + void add_term_helper(BlueprintType& b, TermType term_value) { if (integer_type) { b.addTerm(IntegerKey(term_value), 1, estimate); } else { b.addTerm(StringKey(term_value), 1, estimate); } } - void add_term(int64_t term_value) { + template <typename TermType> + void add_term(TermType term_value) { if (single_type) { if (in_operator) { add_term_helper(dynamic_cast<SingleInBlueprintType&>(*blueprint), term_value); @@ -239,11 +253,24 @@ public: } } } - std::unique_ptr<SearchIterator> create_leaf_search() const { - return blueprint->createLeafSearch(tfmda, true); + void add_terms(const std::vector<int64_t>& term_values) { + for (auto value : term_values) { + add_term(value); + } } - vespalib::string multi_term_iterator() const { - 
return in_operator ? "search::attribute::MultiTermOrFilterSearchImpl" : "search::queryeval::WeightedSetTermSearchImpl"; + void add_terms(const std::vector<vespalib::string>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + std::unique_ptr<SearchIterator> create_leaf_search(bool strict = true) const { + return blueprint->createLeafSearch(tfmda, strict); + } + vespalib::string resolve_iterator_with_unpack() const { + if (in_operator) { + return iterator_unpack_docid; + } + return field_is_filter ? iterator_unpack_docid : iterator_unpack_docid_and_weights; } }; @@ -252,7 +279,7 @@ expect_hits(const Docids& exp_docids, SearchIterator& itr) { SimpleResult exp(exp_docids); SimpleResult act; - act.search(itr); + act.search(itr, doc_id_limit); EXPECT_EQ(exp, act); } @@ -284,61 +311,54 @@ INSTANTIATE_TEST_SUITE_P(DefaultInstantiation, TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_none_filter_field) { setup(false, true); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); - EXPECT_THAT(itr->asString(), StartsWith(multi_term_iterator())); + EXPECT_THAT(itr->asString(), StartsWith(resolve_iterator_with_unpack())); expect_hits({10, 30, 31}, *itr); } -TEST_P(DirectMultiTermBlueprintTest, bitvectors_used_instead_of_btree_iterators_for_none_filter_field) +TEST_P(DirectMultiTermBlueprintTest, bitvectors_used_instead_of_btree_iterators_for_in_operator) { setup(false, true); if (!in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); - expect_or_child(*itr, 1, multi_term_iterator()); + expect_or_child(*itr, 1, iterator_unpack_docid); expect_hits(concat({10}, range(100, 128)), *itr); } -TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_for_none_filter_field) +TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_for_wset_operator) { setup(false, true); if (in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); - EXPECT_THAT(itr->asString(), StartsWith(multi_term_iterator())); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid_and_weights)); expect_hits(concat({10}, range(100, 128)), *itr); } TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field) { setup(true, true); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); expect_or_child(*itr, 1, "search::BitVectorIteratorStrictT"); - expect_or_child(*itr, 2, multi_term_iterator()); + expect_or_child(*itr, 2, iterator_unpack_docid); expect_hits(concat({10, 30, 31}, concat(range(100, 128), range(300, 128))), *itr); } TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) { setup(true, true); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -346,36 +366,31 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) expect_hits(concat(range(100, 128), range(300, 128)), *itr); } -TEST_P(DirectMultiTermBlueprintTest, or_filter_iterator_used_for_filter_field_when_ranking_not_needed) +TEST_P(DirectMultiTermBlueprintTest, 
btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); - EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermOrFilterSearchImpl")); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_none)); expect_hits({10, 30, 31}, *itr); } -TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_or_filter_iterator_used_for_filter_field_when_ranking_not_needed) +TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); expect_or_child(*itr, 1, "search::BitVectorIteratorStrictT"); - expect_or_child(*itr, 2, "search::attribute::MultiTermOrFilterSearchImpl"); + expect_or_child(*itr, 2, iterator_unpack_none); expect_hits(concat({10, 30, 31}, concat(range(100, 128), range(300, 128))), *itr); } TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -383,4 +398,41 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ expect_hits(concat(range(100, 128), range(300, 128)), *itr); } +TEST_P(DirectMultiTermBlueprintTest, hash_filter_used_for_non_strict_iterator_with_10_or_more_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_non_strict_iterator_with_9_or_less_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid)); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, hash_filter_with_string_folding_used_for_non_strict_iterator) +{ + setup(true, true); + if (!single_type || integer_type) { + return; + } + // "foo" matches documents with "foo" (40) and "Foo" (41). 
+ add_terms({"foo", "3", "3", "3", "3", "3", "3", "3", "3", "3"}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({30, 31, 40, 41}, *itr); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp b/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp index c1e12580559..2ddd211bd12 100644 --- a/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp +++ b/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp @@ -142,7 +142,11 @@ TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidPostingStore_int TEST(DirectPostingStoreApiTest, attributes_support_IDocidWithWeightPostingStore_interface) { expect_docid_with_weight_posting_store(BasicType::INT64, CollectionType::WSET, true); + expect_docid_with_weight_posting_store(BasicType::INT32, CollectionType::WSET, true); expect_docid_with_weight_posting_store(BasicType::STRING, CollectionType::WSET, true); + expect_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, true); + expect_docid_with_weight_posting_store(BasicType::INT32, CollectionType::ARRAY, true); + expect_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, true); } TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidWithWeightPostingStore_interface) { @@ -150,13 +154,11 @@ TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidWithWeightPostin expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, false); expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::WSET, false); expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::SINGLE, true); - expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, true); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::SINGLE, false); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, false); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::WSET, false); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::SINGLE, true); - expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, true); - expect_not_docid_with_weight_posting_store(BasicType::INT32, CollectionType::WSET, true); + expect_not_docid_with_weight_posting_store(BasicType::DOUBLE, CollectionType::ARRAY, true); expect_not_docid_with_weight_posting_store(BasicType::DOUBLE, CollectionType::WSET, true); } diff --git a/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp b/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp index ecc03ac54c5..3b346601245 100644 --- a/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp @@ -473,35 +473,6 @@ TEST("require that attribute dot product can produce no hits") { } } -TEST("require that direct attribute iterators work") { - for (int i = 0; i <= 0x3; ++i) { - bool fast_search = ((i & 0x1) != 0); - bool strict = ((i & 0x2) != 0); - MyAttributeManager attribute_manager = make_weighted_string_attribute_manager(fast_search); - SimpleStringTerm empty_node("notfoo", "", 0, Weight(1)); - Result empty_result = do_search(attribute_manager, empty_node, 
strict); - EXPECT_EQUAL(0u, empty_result.hits.size()); - SimpleStringTerm node("foo", "", 0, Weight(1)); - Result result = do_search(attribute_manager, node, strict); - if (fast_search) { - EXPECT_EQUAL(3u, result.est_hits); - EXPECT_TRUE(result.has_minmax); - EXPECT_EQUAL(100, result.min_weight); - EXPECT_EQUAL(1000, result.max_weight); - EXPECT_TRUE(result.iterator_dump.find("DocidWithWeightSearchIterator") != vespalib::string::npos); - } else { - EXPECT_EQUAL(num_docs, result.est_hits); - EXPECT_FALSE(result.has_minmax); - EXPECT_TRUE(result.iterator_dump.find("DocidWithWeightSearchIterator") == vespalib::string::npos); - } - ASSERT_EQUAL(3u, result.hits.size()); - EXPECT_FALSE(result.est_empty); - EXPECT_EQUAL(20u, result.hits[0].docid); - EXPECT_EQUAL(40u, result.hits[1].docid); - EXPECT_EQUAL(50u, result.hits[2].docid); - } -} - TEST("require that single weighted set turns filter on filter fields") { bool fast_search = true; bool strict = true; diff --git a/searchlib/src/tests/features/prod_features_test.cpp b/searchlib/src/tests/features/prod_features_test.cpp index 2966e7d4b07..472bb545f33 100644 --- a/searchlib/src/tests/features/prod_features_test.cpp +++ b/searchlib/src/tests/features/prod_features_test.cpp @@ -918,27 +918,40 @@ Test::testDistance() assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), "5:-5", 10, -20); assert2DZDistance(static_cast<feature_t>(std::sqrt(450.0f)), "5:-5", -10, -20); assert2DZDistance(static_cast<feature_t>(std::sqrt(850.0f)), "5:-5", -10, 20); - assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), "5:-5", 15, -20, 0x80000000); // 2^31 + assert2DZDistance(static_cast<feature_t>(std::sqrt(325.0f)), "5:-5", 15, -20, 0x80000000); // 2^31 } { // test 2D multi location (zcurve) - vespalib::string positions = "5:-5,35:0,5:40,35:-40"; - assert2DZDistance(static_cast<feature_t>(std::sqrt(425.0f)), positions, 10, 20, 0, 2); - assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 10, -20, 0, 0); - assert2DZDistance(static_cast<feature_t>(std::sqrt(450.0f)), positions, -10, -20, 0, 0); - assert2DZDistance(static_cast<feature_t>(std::sqrt(625.0f)), positions, -10, 20, 0, 2); - assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 15, -20, 0x80000000, 0); // 2^31 - assert2DZDistance(static_cast<feature_t>(std::sqrt(425.0f)), positions, 45, -20, 0x80000000, 1); // 2^31 + // note: "aspect" is ignored now, computed from "y", and cos(60 degrees) = 0.5 + vespalib::string positions = "5:59999995," "35:60000000," "5:60000040," "35:59999960"; + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, 5, 59999995, 0, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, 35, 60000000, 0x10000000, 1)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, 5, 60000040, 0x20000000, 2)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, 35, 59999960, 0x30000000, 3)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 15, 59999980, 0x40000000, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, -5, 59999980, 0x50000000, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 45, 59999985, 0x60000000, 1)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 45, 60000015, 0x70000000, 1)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(425.0f)), positions, 15, 60000020, 0x80000000, 2)); + 
TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(425.0f)), positions, -5, 60000020, 0x90000000, 2)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(50.0f)), positions, 45, 59999955, 0xa0000000, 3)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(50.0f)), positions, 45, 59999965, 0xb0000000, 3)); + + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(450.0f)), positions, -25, 59999980, 0xc0000000, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(625.0f)), positions, -25, 60000060, 0xd0000000, 2)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(250.0f)), positions, 15, 59999980, 0xe0000000, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(std::sqrt(425.0f)), positions, 45, 59999980, 0xf0000000, 1)); } { // test geo multi location (zcurve) - vespalib::string positions = "0:0,100:100,-200:200,-300:-300,400:-400"; - assert2DZDistance(static_cast<feature_t>(0.0f), positions, 0, 0, 0x40000000, 0); - assert2DZDistance(static_cast<feature_t>(1.0f), positions, 100, 101, 0x40000000, 1); - assert2DZDistance(static_cast<feature_t>(0.0f), positions, -200, 200, 0x40000000, 2); - assert2DZDistance(static_cast<feature_t>(13.0f), positions, -320, -312, 0x40000000, 3); - assert2DZDistance(static_cast<feature_t>(5.0f), positions, 416, -403, 0x40000000, 4); - assert2DZDistance(static_cast<feature_t>(5.0f), positions, 112, 104, 0x40000000, 1); + // note: cos(70.528779 degrees) = 1/3 + vespalib::string positions = "0:70528779," "100:70528879," "-200:70528979," "-300:70528479," "400:70528379"; + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, 0, 70528779 + 0, 0, 0)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(1.0f), positions, 100, 70528779 + 101, 0x20000000, 1)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(0.0f), positions, -200, 70528779 + 200, 0x40000000, 2)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(13.0f), positions, -315, 70528779 -312, 0x80000000, 3)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(5.0f), positions, 412, 70528779 -403, 0xB0000000, 4)); + TEST_DO(assert2DZDistance(static_cast<feature_t>(5.0f), positions, 109, 70528779 + 104, 0xF0000000, 1)); } { // test default distance @@ -1031,15 +1044,15 @@ Test::assert2DZDistance(feature_t exp, const vespalib::string & positions, GeoLocation::Aspect aspect{xAspect}; ft.getQueryEnv().addLocation(GeoLocationSpec{"pos", {p, aspect}}); ASSERT_TRUE(ft.setup()); - ASSERT_TRUE(ft.execute(RankResult().setEpsilon(1e-4). + EXPECT_TRUE(ft.execute(RankResult().setEpsilon(1e-4). addScore("distance(pos)", exp))); - ASSERT_TRUE(ft.execute(RankResult().setEpsilon(1e-4). + EXPECT_TRUE(ft.execute(RankResult().setEpsilon(1e-4). addScore("distance(pos).km", exp * 0.00011119508023))); - ASSERT_TRUE(ft.execute(RankResult().setEpsilon(1e-30). + EXPECT_TRUE(ft.execute(RankResult().setEpsilon(1e-30). addScore("distance(pos).index", hit_index))); - ASSERT_TRUE(ft.execute(RankResult().setEpsilon(1e-9). + EXPECT_TRUE(ft.execute(RankResult().setEpsilon(1e-9). addScore("distance(pos).latitude", pos[hit_index].second * 1e-6))); - ASSERT_TRUE(ft.execute(RankResult().setEpsilon(1e-9). + EXPECT_TRUE(ft.execute(RankResult().setEpsilon(1e-9). 
addScore("distance(pos).longitude", pos[hit_index].first * 1e-6))); } diff --git a/searchlib/src/tests/nearsearch/nearsearch_test.cpp b/searchlib/src/tests/nearsearch/nearsearch_test.cpp index 3751fc93cea..95701e59444 100644 --- a/searchlib/src/tests/nearsearch/nearsearch_test.cpp +++ b/searchlib/src/tests/nearsearch/nearsearch_test.cpp @@ -215,7 +215,7 @@ bool Test::testNearSearch(MyQuery &query, uint32_t matchId) { LOG(info, "testNearSearch(%d)", matchId); - search::queryeval::IntermediateBlueprint *near_b = 0; + search::queryeval::IntermediateBlueprint *near_b = nullptr; if (query.isOrdered()) { near_b = new search::queryeval::ONearBlueprint(query.getWindow()); } else { @@ -228,9 +228,10 @@ Test::testNearSearch(MyQuery &query, uint32_t matchId) layout.allocTermField(fieldId); near_b->addChild(query.getTerm(i).make_blueprint(fieldId, i)); } - search::fef::MatchData::UP md(layout.createMatchData()); - + bp->setDocIdLimit(1000); + bp = search::queryeval::Blueprint::optimize_and_sort(std::move(bp), true, true); bp->fetchPostings(search::queryeval::ExecuteInfo::TRUE); + search::fef::MatchData::UP md(layout.createMatchData()); search::queryeval::SearchIterator::UP near = bp->createSearch(*md, true); near->initFullRange(); bool foundMatch = false; diff --git a/searchlib/src/tests/query/streaming/CMakeLists.txt b/searchlib/src/tests/query/streaming/CMakeLists.txt new file mode 100644 index 00000000000..7568e45d00a --- /dev/null +++ b/searchlib/src/tests/query/streaming/CMakeLists.txt @@ -0,0 +1,46 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +vespa_add_executable(searchlib_query_streaming_hit_iterator_test_app TEST + SOURCES + hit_iterator_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_hit_iterator_test_app COMMAND searchlib_query_streaming_hit_iterator_test_app) + +vespa_add_executable(searchlib_query_streaming_hit_iterator_pack_test_app TEST + SOURCES + hit_iterator_pack_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_hit_iterator_pack_test_app COMMAND searchlib_query_streaming_hit_iterator_pack_test_app) + +vespa_add_executable(searchlib_query_streaming_near_test_app TEST + SOURCES + near_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_near_test_app COMMAND searchlib_query_streaming_near_test_app) + +vespa_add_executable(searchlib_query_streaming_same_element_query_node_test_app TEST + SOURCES + same_element_query_node_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_same_element_query_node_test_app COMMAND searchlib_query_streaming_same_element_query_node_test_app) + +vespa_add_executable(searchlib_query_streaming_phrase_query_node_test_app TEST + SOURCES + phrase_query_node_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_phrase_query_node_test_app COMMAND searchlib_query_streaming_phrase_query_node_test_app) diff --git a/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp b/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp new file mode 100644 index 00000000000..7d7d8307920 --- /dev/null +++ b/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp @@ -0,0 +1,44 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/searchlib/query/streaming/hit_iterator_pack.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::streaming::HitIterator; +using search::streaming::HitIteratorPack; +using search::streaming::QueryNodeList; +using search::streaming::QueryTerm; +using search::streaming::QueryNodeResultBase; + +using FieldElement = HitIterator::FieldElement; + +TEST(HitIteratorPackTest, seek_to_matching_field_element) +{ + QueryNodeList qnl; + auto qt = std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTerm::Type::WORD); + qt->add(11, 0, 10, 0); + qt->add(11, 0, 10, 5); + qt->add(11, 1, 12, 0); + qt->add(11, 1, 12, 0); + qt->add(12, 1, 13, 0); + qt->add(12, 1, 13, 0); + qnl.emplace_back(std::move(qt)); + qt = std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "8", "", QueryTerm::Type::WORD); + qt->add(2, 0, 4, 0); + qt->add(11, 0, 10, 0); + qt->add(12, 1, 13, 0); + qt->add(12, 2, 14, 0); + qnl.emplace_back(std::move(qt)); + HitIteratorPack itr_pack(qnl); + EXPECT_TRUE(itr_pack.all_valid()); + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(11, 0), itr_pack.get_field_element_ref()); + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(11, 0), itr_pack.get_field_element_ref()); + ++itr_pack.get_field_element_ref().second; + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(12, 1), itr_pack.get_field_element_ref()); + ++itr_pack.get_field_element_ref().second; + EXPECT_FALSE(itr_pack.seek_to_matching_field_element()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/hit_iterator_test.cpp b/searchlib/src/tests/query/streaming/hit_iterator_test.cpp new file mode 100644 index 00000000000..a9588ea3d6c --- /dev/null +++ b/searchlib/src/tests/query/streaming/hit_iterator_test.cpp @@ -0,0 +1,122 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+
+#include <vespa/searchlib/query/streaming/hit_iterator.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using search::streaming::Hit;
+using search::streaming::HitList;
+using search::streaming::HitIterator;
+
+using FieldElement = HitIterator::FieldElement;
+
+namespace {
+
+HitList
+make_hit_list()
+{
+    HitList hl;
+    hl.emplace_back(11, 0, 10, 0);
+    hl.emplace_back(11, 0, 10, 5);
+    hl.emplace_back(11, 1, 12, 0);
+    hl.emplace_back(11, 1, 12, 7);
+    hl.emplace_back(12, 1, 13, 0);
+    hl.emplace_back(12, 1, 13, 9);
+    return hl;
+}
+
+void
+check_seek_to_field_elem(HitIterator& it, const FieldElement& field_element, const Hit* exp_ptr, const vespalib::string& label)
+{
+    SCOPED_TRACE(label);
+    EXPECT_TRUE(it.seek_to_field_element(field_element));
+    EXPECT_TRUE(it.valid());
+    EXPECT_EQ(exp_ptr, &*it);
+}
+
+void
+check_seek_to_field_elem_failure(HitIterator& it, const FieldElement& field_element, const vespalib::string& label)
+{
+    SCOPED_TRACE(label);
+    EXPECT_FALSE(it.seek_to_field_element(field_element));
+    EXPECT_FALSE(it.valid());
+}
+
+void
+check_step_in_field_element(HitIterator& it, FieldElement& field_element, bool exp_success, const Hit* exp_ptr, const vespalib::string& label)
+{
+    SCOPED_TRACE(label);
+    EXPECT_EQ(exp_success, it.step_in_field_element(field_element));
+    if (exp_ptr) {
+        EXPECT_TRUE(it.valid());
+        EXPECT_EQ(it.get_field_element(), field_element);
+        EXPECT_EQ(exp_ptr, &*it);
+    } else {
+        EXPECT_FALSE(it.valid());
+    }
+}
+
+void
+check_seek_in_field_element(HitIterator& it, uint32_t position, FieldElement& field_element, bool exp_success, const Hit* exp_ptr, const vespalib::string& label)
+{
+    SCOPED_TRACE(label);
+    EXPECT_EQ(exp_success, it.seek_in_field_element(position, field_element));
+    if (exp_ptr) {
+        EXPECT_TRUE(it.valid());
+        EXPECT_EQ(it.get_field_element(), field_element);
+        EXPECT_EQ(exp_ptr, &*it);
+    } else {
+        EXPECT_FALSE(it.valid());
+    }
+}
+
+}
+
+TEST(HitIteratorTest, seek_to_field_element)
+{
+    auto hl = make_hit_list();
+    HitIterator it(hl);
+    EXPECT_TRUE(it.valid());
+    EXPECT_EQ(&hl[0], &*it);
+    check_seek_to_field_elem(it, FieldElement(0, 0), &hl[0], "(0, 0)");
+    check_seek_to_field_elem(it, FieldElement(11, 0), &hl[0], "(11, 0)");
+    check_seek_to_field_elem(it, FieldElement(11, 1), &hl[2], "(11, 1)");
+    check_seek_to_field_elem(it, FieldElement(11, 2), &hl[4], "(11, 2)");
+    check_seek_to_field_elem(it, FieldElement(12, 0), &hl[4], "(12, 0)");
+    check_seek_to_field_elem(it, FieldElement(12, 1), &hl[4], "(12, 1)");
+    check_seek_to_field_elem_failure(it, FieldElement(12, 2), "(12, 2)");
+    check_seek_to_field_elem_failure(it, FieldElement(13, 0), "(13, 0)");
+}
+
+TEST(HitIteratorTest, step_in_field_element)
+{
+    auto hl = make_hit_list();
+    HitIterator it(hl);
+    auto field_element = it.get_field_element();
+    check_step_in_field_element(it, field_element, true, &hl[1], "1");
+    check_step_in_field_element(it, field_element, false, &hl[2], "2");
+    check_step_in_field_element(it, field_element, true, &hl[3], "3");
+    check_step_in_field_element(it, field_element, false, &hl[4], "4");
+    check_step_in_field_element(it, field_element, true, &hl[5], "5");
+    check_step_in_field_element(it, field_element, false, nullptr, "end");
+}
+
+TEST(HitIteratorTest, seek_in_field_elem)
+{
+    auto hl = make_hit_list();
+    HitIterator it(hl);
+    auto field_element = it.get_field_element();
+    check_seek_in_field_element(it, 0, field_element, true, &hl[0], "0a");
+    check_seek_in_field_element(it, 2, field_element, true, &hl[1], "2");
+    check_seek_in_field_element(it, 5,
field_element, true, &hl[1], "5"); + check_seek_in_field_element(it, 6, field_element, false, &hl[2], "6"); + check_seek_in_field_element(it, 0, field_element, true, &hl[2], "0b"); + check_seek_in_field_element(it, 1, field_element, true, &hl[3], "1"); + check_seek_in_field_element(it, 7, field_element, true, &hl[3], "7"); + check_seek_in_field_element(it, 8, field_element, false, &hl[4], "8"); + check_seek_in_field_element(it, 0, field_element, true, &hl[4], "0c"); + check_seek_in_field_element(it, 3, field_element, true, &hl[5], "3"); + check_seek_in_field_element(it, 9, field_element, true, &hl[5], "9"); + check_seek_in_field_element(it, 10, field_element, false, nullptr, "end"); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/near_test.cpp b/searchlib/src/tests/query/streaming/near_test.cpp new file mode 100644 index 00000000000..2f95eb13dbd --- /dev/null +++ b/searchlib/src/tests/query/streaming/near_test.cpp @@ -0,0 +1,185 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/query/streaming/near_query_node.h> +#include <vespa/searchlib/query/streaming/onear_query_node.h> +#include <vespa/searchlib/query/streaming/queryterm.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/stllike/asciistream.h> +#include <ostream> +#include <tuple> + +using TestHit = std::tuple<uint32_t, uint32_t, int32_t, uint32_t>; + +using search::query::QueryBuilder; +using search::query::Node; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::query::Weight; +using search::streaming::NearQueryNode; +using search::streaming::ONearQueryNode; +using search::streaming::Query; +using search::streaming::QueryNodeResultFactory; +using search::streaming::QueryTerm; +using search::streaming::QueryTermList; + +class TestParam { + bool _ordered; +public: + TestParam(bool ordered_in) + : _ordered(ordered_in) + { + } + bool ordered() const noexcept { return _ordered; } +}; + +std::ostream& operator<<(std::ostream& os, const TestParam& param) +{ + os << (param.ordered() ? 
"onear" : "near"); + return os; + +} +class NearTest : public ::testing::TestWithParam<TestParam> { +public: + NearTest(); + ~NearTest(); + bool evaluate_query(uint32_t distance, const std::vector<std::vector<TestHit>> &hitsvv); +}; + +NearTest::NearTest() + : ::testing::TestWithParam<TestParam>() +{ +} + +NearTest::~NearTest() = default; + +bool +NearTest::evaluate_query(uint32_t distance, const std::vector<std::vector<TestHit>> &hitsvv) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + if (GetParam().ordered()) { + builder.addONear(hitsvv.size(), distance); + } else { + builder.addNear(hitsvv.size(), distance); + } + for (uint32_t idx = 0; idx < hitsvv.size(); ++idx) { + vespalib::asciistream s; + s << "s" << idx; + builder.addStringTerm(s.str(), "field", idx, Weight(0)); + } + auto node = builder.build(); + vespalib::string stackDump = StackDumpCreator::create(*node); + QueryNodeResultFactory empty; + auto q = std::make_unique<Query>(empty, stackDump); + if (GetParam().ordered()) { + auto& top = dynamic_cast<ONearQueryNode&>(q->getRoot()); + EXPECT_EQ(hitsvv.size(), top.size()); + } else { + auto& top = dynamic_cast<NearQueryNode&>(q->getRoot()); + EXPECT_EQ(hitsvv.size(), top.size()); + } + QueryTermList terms; + q->getLeaves(terms); + EXPECT_EQ(hitsvv.size(), terms.size()); + for (QueryTerm * qt : terms) { + qt->resizeFieldId(1); + } + for (uint32_t idx = 0; idx < hitsvv.size(); ++idx) { + auto& hitsv = hitsvv[idx]; + auto& term = terms[idx]; + for (auto& hit : hitsv) { + term->add(std::get<0>(hit), std::get<1>(hit), std::get<2>(hit), std::get<3>(hit)); + } + } + return q->getRoot().evaluate(); +} + +TEST_P(NearTest, test_empty_near) +{ + EXPECT_FALSE(evaluate_query(4, { })); +} + +TEST_P(NearTest, test_near_success) +{ + EXPECT_TRUE(evaluate_query(4, { { { 0, 0, 10, 0} }, + { { 0, 0, 10, 2} }, + { { 0, 0, 10, 4} } })); +} + +TEST_P(NearTest, test_near_fail_distance_exceeded_first_term) +{ + EXPECT_FALSE(evaluate_query(4, { { { 0, 0, 10, 0} }, + { { 0, 0, 10, 2} }, + { { 0, 0, 10, 5} } })); +} + +TEST_P(NearTest, test_near_fail_distance_exceeded_second_term) +{ + EXPECT_FALSE(evaluate_query(4, { { { 0, 0, 10, 2} }, + { { 0, 0, 10, 0} }, + { { 0, 0, 10, 5} } })); +} + +TEST_P(NearTest, test_near_fail_element) +{ + EXPECT_FALSE(evaluate_query(4, { { { 0, 0, 10, 0} }, + { { 0, 0, 10, 2} }, + { { 0, 1, 10, 4} } })); +} + +TEST_P(NearTest, test_near_fail_field) +{ + EXPECT_FALSE(evaluate_query(4, { { { 0, 0, 10, 0} }, + { { 0, 0, 10, 2} }, + { { 1, 0, 10, 4} } })); +} + +TEST_P(NearTest, test_near_success_after_step_first_term) +{ + EXPECT_TRUE(evaluate_query(4, { { { 0, 0, 10, 0}, { 0, 0, 10, 2} }, + { { 0, 0, 10, 3} }, + { { 0, 0, 10, 5} } })); +} + +TEST_P(NearTest, test_near_success_after_step_second_term) +{ + EXPECT_TRUE(evaluate_query(4, { { { 0, 0, 10, 2} }, + { { 0, 0, 10, 0}, {0, 0, 10, 3} }, + { { 0, 0, 10, 5} } })); +} + +TEST_P(NearTest, test_near_success_in_second_element) +{ + EXPECT_TRUE(evaluate_query(4, { { { 0, 0, 10, 0}, { 0, 1, 10, 0} }, + { { 0, 0, 10, 2}, { 0, 1, 10, 2} }, + { { 0, 0, 10, 5}, { 0, 1, 10, 4} } })); +} + +TEST_P(NearTest, test_near_success_in_second_field) +{ + EXPECT_TRUE(evaluate_query(4, { { { 0, 0, 10, 0}, { 1, 0, 10, 0} }, + { { 0, 0, 10, 2}, { 1, 0, 10, 2} }, + { { 0, 0, 10, 5}, { 1, 0, 10, 4} } })); +} + +TEST_P(NearTest, test_order_might_matter) +{ + EXPECT_EQ(!GetParam().ordered(), evaluate_query(4, { { { 0, 0, 10, 2} }, + { { 0, 0, 10, 0} }, + { { 0, 0, 10, 4} } })); +} + +TEST_P(NearTest, test_overlap_might_matter) +{ + 
EXPECT_EQ(!GetParam().ordered(), evaluate_query(4, { { { 0, 0, 10, 0} }, + { { 0, 0, 10, 0} }, + { { 0, 0, 10, 4} } })); +} + +auto test_values = ::testing::Values(TestParam(false), TestParam(true)); + +INSTANTIATE_TEST_SUITE_P(NearTests, NearTest, test_values, testing::PrintToStringParamName()); + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp b/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp new file mode 100644 index 00000000000..5caae8d6e97 --- /dev/null +++ b/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp @@ -0,0 +1,95 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/query/streaming/phrase_query_node.h> +#include <vespa/searchlib/query/streaming/queryterm.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::query::QueryBuilder; +using search::query::Node; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::query::Weight; +using search::streaming::HitList; +using search::streaming::PhraseQueryNode; +using search::streaming::Query; +using search::streaming::QueryTerm; +using search::streaming::QueryNodeRefList; +using search::streaming::QueryNodeResultFactory; +using search::streaming::QueryTermList; + +TEST(PhraseQueryNodeTest, test_phrase_evaluate) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + builder.addPhrase(3, "", 0, Weight(0)); + { + builder.addStringTerm("a", "", 0, Weight(0)); + builder.addStringTerm("b", "", 0, Weight(0)); + builder.addStringTerm("c", "", 0, Weight(0)); + } + Node::UP node = builder.build(); + vespalib::string stackDump = StackDumpCreator::create(*node); + QueryNodeResultFactory empty; + Query q(empty, stackDump); + QueryNodeRefList phrases; + q.getPhrases(phrases); + QueryTermList terms; + q.getLeaves(terms); + for (QueryTerm * qt : terms) { + qt->resizeFieldId(1); + } + + // field 0 + terms[0]->add(0, 0, 1, 0); + terms[1]->add(0, 0, 1, 1); + terms[2]->add(0, 0, 1, 2); + terms[0]->add(0, 0, 1, 7); + terms[1]->add(0, 1, 1, 8); + terms[2]->add(0, 0, 1, 9); + // field 1 + terms[0]->add(1, 0, 1, 4); + terms[1]->add(1, 0, 1, 5); + terms[2]->add(1, 0, 1, 6); + // field 2 (not complete match) + terms[0]->add(2, 0, 1, 1); + terms[1]->add(2, 0, 1, 2); + terms[2]->add(2, 0, 1, 4); + // field 3 + terms[0]->add(3, 0, 1, 0); + terms[1]->add(3, 0, 1, 1); + terms[2]->add(3, 0, 1, 2); + // field 4 (not complete match) + terms[0]->add(4, 0, 1, 1); + terms[1]->add(4, 0, 1, 2); + // field 5 (not complete match) + terms[0]->add(5, 0, 1, 2); + terms[1]->add(5, 0, 1, 1); + terms[2]->add(5, 0, 1, 0); + HitList hits; + auto * p = static_cast<PhraseQueryNode *>(phrases[0]); + p->evaluateHits(hits); + ASSERT_EQ(3u, hits.size()); + EXPECT_EQ(0u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(2u, hits[0].position()); + EXPECT_EQ(1u, hits[1].field_id()); + EXPECT_EQ(0u, hits[1].element_id()); + EXPECT_EQ(6u, hits[1].position()); + EXPECT_EQ(3u, hits[2].field_id()); + EXPECT_EQ(0u, hits[2].element_id()); + EXPECT_EQ(2u, hits[2].position()); + ASSERT_EQ(4u, p->getFieldInfoSize()); + EXPECT_EQ(0u, p->getFieldInfo(0).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(0).getHitCount()); + EXPECT_EQ(1u, p->getFieldInfo(1).getHitOffset()); + EXPECT_EQ(1u, 
p->getFieldInfo(1).getHitCount()); + EXPECT_EQ(0u, p->getFieldInfo(2).getHitOffset()); // invalid, but will never be used + EXPECT_EQ(0u, p->getFieldInfo(2).getHitCount()); + EXPECT_EQ(2u, p->getFieldInfo(3).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(3).getHitCount()); + EXPECT_TRUE(p->evaluate()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp b/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp new file mode 100644 index 00000000000..ece6dc551b2 --- /dev/null +++ b/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp @@ -0,0 +1,135 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/query/streaming/same_element_query_node.h> +#include <vespa/searchlib/query/streaming/queryterm.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::query::QueryBuilder; +using search::query::Node; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::query::Weight; +using search::streaming::HitList; +using search::streaming::Query; +using search::streaming::QueryNode; +using search::streaming::QueryNodeResultFactory; +using search::streaming::QueryTerm; +using search::streaming::QueryTermList; +using search::streaming::SameElementQueryNode; + +namespace { + +class AllowRewrite : public QueryNodeResultFactory +{ +public: + explicit AllowRewrite(vespalib::stringref index) noexcept : _allowedIndex(index) {} + bool allow_float_terms_rewrite(vespalib::stringref index) const noexcept override { return index == _allowedIndex; } +private: + vespalib::string _allowedIndex; +}; + +} + +TEST(SameElementQueryNodeTest, a_unhandled_sameElement_stack) +{ + const char * stack = "\022\002\026xyz_abcdefghij_xyzxyzxQ\001\vxxxxxx_name\034xxxxxx_xxxx_xxxxxxx_xxxxxxxxE\002\005delta\b<0.00393"; + vespalib::stringref stackDump(stack); + EXPECT_EQ(85u, stackDump.size()); + AllowRewrite empty(""); + const Query q(empty, stackDump); + EXPECT_TRUE(q.valid()); + const QueryNode & root = q.getRoot(); + auto sameElement = dynamic_cast<const SameElementQueryNode *>(&root); + EXPECT_TRUE(sameElement != nullptr); + EXPECT_EQ(2u, sameElement->size()); + EXPECT_EQ("xyz_abcdefghij_xyzxyzx", sameElement->getIndex()); + auto term0 = dynamic_cast<const QueryTerm *>((*sameElement)[0].get()); + EXPECT_TRUE(term0 != nullptr); + auto term1 = dynamic_cast<const QueryTerm *>((*sameElement)[1].get()); + EXPECT_TRUE(term1 != nullptr); +} + +namespace { + void verifyQueryTermNode(const vespalib::string & index, const QueryNode *node) { + EXPECT_TRUE(dynamic_cast<const QueryTerm *>(node) != nullptr); + EXPECT_EQ(index, node->getIndex()); + } +} + +TEST(SameElementQueryNodeTest, test_same_element_evaluate) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + builder.addSameElement(3, "field", 0, Weight(0)); + { + builder.addStringTerm("a", "f1", 0, Weight(0)); + builder.addStringTerm("b", "f2", 1, Weight(0)); + builder.addStringTerm("c", "f3", 2, Weight(0)); + } + Node::UP node = builder.build(); + vespalib::string stackDump = StackDumpCreator::create(*node); + QueryNodeResultFactory empty; + Query q(empty, stackDump); + auto * sameElem = dynamic_cast<SameElementQueryNode *>(&q.getRoot()); + EXPECT_TRUE(sameElem != nullptr); + EXPECT_EQ("field", 
sameElem->getIndex()); + EXPECT_EQ(3u, sameElem->size()); + verifyQueryTermNode("field.f1", (*sameElem)[0].get()); + verifyQueryTermNode("field.f2", (*sameElem)[1].get()); + verifyQueryTermNode("field.f3", (*sameElem)[2].get()); + + QueryTermList terms; + q.getLeaves(terms); + EXPECT_EQ(3u, terms.size()); + for (QueryTerm * qt : terms) { + qt->resizeFieldId(3); + } + + // field 0 + terms[0]->add(0, 0, 10, 1); + terms[0]->add(0, 1, 20, 2); + terms[0]->add(0, 2, 30, 3); + terms[0]->add(0, 3, 40, 4); + terms[0]->add(0, 4, 50, 5); + terms[0]->add(0, 5, 60, 6); + + terms[1]->add(1, 0, 70, 7); + terms[1]->add(1, 1, 80, 8); + terms[1]->add(1, 2, 90, 9); + terms[1]->add(1, 4, 100, 10); + terms[1]->add(1, 5, 110, 11); + terms[1]->add(1, 6, 120, 12); + + terms[2]->add(2, 0, 130, 13); + terms[2]->add(2, 2, 140, 14); + terms[2]->add(2, 4, 150, 15); + terms[2]->add(2, 5, 160, 16); + terms[2]->add(2, 6, 170, 17); + HitList hits; + sameElem->evaluateHits(hits); + EXPECT_EQ(4u, hits.size()); + EXPECT_EQ(2u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(130, hits[0].element_weight()); + EXPECT_EQ(0u, hits[0].position()); + + EXPECT_EQ(2u, hits[1].field_id()); + EXPECT_EQ(2u, hits[1].element_id()); + EXPECT_EQ(140, hits[1].element_weight()); + EXPECT_EQ(0u, hits[1].position()); + + EXPECT_EQ(2u, hits[2].field_id()); + EXPECT_EQ(4u, hits[2].element_id()); + EXPECT_EQ(150, hits[2].element_weight()); + EXPECT_EQ(0u, hits[2].position()); + + EXPECT_EQ(2u, hits[3].field_id()); + EXPECT_EQ(5u, hits[3].element_id()); + EXPECT_EQ(160, hits[3].element_weight()); + EXPECT_EQ(0u, hits[3].position()); + EXPECT_TRUE(sameElem->evaluate()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index fe6149e6fba..d2be1d453a2 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -4,6 +4,7 @@ #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> +#include <vespa/searchlib/query/streaming/phrase_query_node.h> #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/searchlib/query/streaming/wand_term.h> @@ -23,10 +24,11 @@ using TermType = QueryTerm::Type; using search::fef::SimpleTermData; using search::fef::MatchData; -void assertHit(const Hit & h, size_t expWordpos, size_t expContext, int32_t weight) { - EXPECT_EQ(h.wordpos(), expWordpos); - EXPECT_EQ(h.context(), expContext); - EXPECT_EQ(h.weight(), weight); +void assertHit(const Hit & h, uint32_t exp_field_id, uint32_t exp_element_id, int32_t exp_element_weight, size_t exp_position) { + EXPECT_EQ(h.field_id(), exp_field_id); + EXPECT_EQ(h.element_id(), exp_element_id); + EXPECT_EQ(h.element_weight(), exp_element_weight); + EXPECT_EQ(h.position(), exp_position); } @@ -427,87 +429,19 @@ TEST(StreamingQueryTest, test_get_query_parts) } } -TEST(StreamingQueryTest, test_phrase_evaluate) +TEST(StreamingQueryTest, test_hit) { - QueryBuilder<SimpleQueryNodeTypes> builder; - builder.addPhrase(3, "", 0, Weight(0)); - { - builder.addStringTerm("a", "", 0, Weight(0)); - builder.addStringTerm("b", "", 0, Weight(0)); - builder.addStringTerm("c", "", 0, Weight(0)); - } - Node::UP node = builder.build(); - vespalib::string stackDump = StackDumpCreator::create(*node); - QueryNodeResultFactory empty; - Query q(empty, 
stackDump); - QueryNodeRefList phrases; - q.getPhrases(phrases); - QueryTermList terms; - q.getLeaves(terms); - for (QueryTerm * qt : terms) { - qt->resizeFieldId(1); - } + // field id + assertHit(Hit( 1, 0, 1, 0), 1, 0, 1, 0); + assertHit(Hit(255, 0, 1, 0), 255, 0, 1, 0); + assertHit(Hit(256, 0, 1, 0), 256, 0, 1, 0); - // field 0 - terms[0]->add(0, 0, 0, 1); - terms[1]->add(1, 0, 0, 1); - terms[2]->add(2, 0, 0, 1); - terms[0]->add(7, 0, 0, 1); - terms[1]->add(8, 0, 1, 1); - terms[2]->add(9, 0, 0, 1); - // field 1 - terms[0]->add(4, 1, 0, 1); - terms[1]->add(5, 1, 0, 1); - terms[2]->add(6, 1, 0, 1); - // field 2 (not complete match) - terms[0]->add(1, 2, 0, 1); - terms[1]->add(2, 2, 0, 1); - terms[2]->add(4, 2, 0, 1); - // field 3 - terms[0]->add(0, 3, 0, 1); - terms[1]->add(1, 3, 0, 1); - terms[2]->add(2, 3, 0, 1); - // field 4 (not complete match) - terms[0]->add(1, 4, 0, 1); - terms[1]->add(2, 4, 0, 1); - // field 5 (not complete match) - terms[0]->add(2, 5, 0, 1); - terms[1]->add(1, 5, 0, 1); - terms[2]->add(0, 5, 0, 1); - HitList hits; - auto * p = static_cast<PhraseQueryNode *>(phrases[0]); - p->evaluateHits(hits); - ASSERT_EQ(3u, hits.size()); - EXPECT_EQ(hits[0].wordpos(), 2u); - EXPECT_EQ(hits[0].context(), 0u); - EXPECT_EQ(hits[1].wordpos(), 6u); - EXPECT_EQ(hits[1].context(), 1u); - EXPECT_EQ(hits[2].wordpos(), 2u); - EXPECT_EQ(hits[2].context(), 3u); - ASSERT_EQ(4u, p->getFieldInfoSize()); - EXPECT_EQ(p->getFieldInfo(0).getHitOffset(), 0u); - EXPECT_EQ(p->getFieldInfo(0).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitOffset(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(2).getHitOffset(), 0u); // invalid, but will never be used - EXPECT_EQ(p->getFieldInfo(2).getHitCount(), 0u); - EXPECT_EQ(p->getFieldInfo(3).getHitOffset(), 2u); - EXPECT_EQ(p->getFieldInfo(3).getHitCount(), 1u); - EXPECT_TRUE(p->evaluate()); -} + // positions + assertHit(Hit(0, 0, 0, 0), 0, 0, 0, 0); + assertHit(Hit(0, 0, 1, 256), 0, 0, 1, 256); + assertHit(Hit(0, 0, -1, 16777215), 0, 0, -1, 16777215); + assertHit(Hit(0, 0, 1, 16777216), 0, 0, 1, 16777216); -TEST(StreamingQueryTest, test_hit) -{ - // positions (0 - (2^24-1)) - assertHit(Hit(0, 0, 0, 0), 0, 0, 0); - assertHit(Hit(256, 0, 0, 1), 256, 0, 1); - assertHit(Hit(16777215, 0, 0, -1), 16777215, 0, -1); - assertHit(Hit(16777216, 0, 0, 1), 0, 1, 1); // overflow - - // contexts (0 - 255) - assertHit(Hit(0, 1, 0, 1), 0, 1, 1); - assertHit(Hit(0, 255, 0, 1), 0, 255, 1); - assertHit(Hit(0, 256, 0, 1), 0, 0, 1); // overflow } void assertInt8Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { @@ -769,105 +703,6 @@ TEST(StreamingQueryTest, require_that_we_do_not_break_the_stack_on_bad_query) EXPECT_FALSE(term.isValid()); } -TEST(StreamingQueryTest, a_unhandled_sameElement_stack) -{ - const char * stack = "\022\002\026xyz_abcdefghij_xyzxyzxQ\001\vxxxxxx_name\034xxxxxx_xxxx_xxxxxxx_xxxxxxxxE\002\005delta\b<0.00393"; - vespalib::stringref stackDump(stack); - EXPECT_EQ(85u, stackDump.size()); - AllowRewrite empty(""); - const Query q(empty, stackDump); - EXPECT_TRUE(q.valid()); - const QueryNode & root = q.getRoot(); - auto sameElement = dynamic_cast<const SameElementQueryNode *>(&root); - EXPECT_TRUE(sameElement != nullptr); - EXPECT_EQ(2u, sameElement->size()); - EXPECT_EQ("xyz_abcdefghij_xyzxyzx", sameElement->getIndex()); - auto term0 = dynamic_cast<const QueryTerm *>((*sameElement)[0].get()); - EXPECT_TRUE(term0 != nullptr); - auto term1 = dynamic_cast<const QueryTerm 
*>((*sameElement)[1].get()); - EXPECT_TRUE(term1 != nullptr); -} - -namespace { - void verifyQueryTermNode(const vespalib::string & index, const QueryNode *node) { - EXPECT_TRUE(dynamic_cast<const QueryTerm *>(node) != nullptr); - EXPECT_EQ(index, node->getIndex()); - } -} - -TEST(StreamingQueryTest, test_same_element_evaluate) -{ - QueryBuilder<SimpleQueryNodeTypes> builder; - builder.addSameElement(3, "field", 0, Weight(0)); - { - builder.addStringTerm("a", "f1", 0, Weight(0)); - builder.addStringTerm("b", "f2", 1, Weight(0)); - builder.addStringTerm("c", "f3", 2, Weight(0)); - } - Node::UP node = builder.build(); - vespalib::string stackDump = StackDumpCreator::create(*node); - QueryNodeResultFactory empty; - Query q(empty, stackDump); - auto * sameElem = dynamic_cast<SameElementQueryNode *>(&q.getRoot()); - EXPECT_TRUE(sameElem != nullptr); - EXPECT_EQ("field", sameElem->getIndex()); - EXPECT_EQ(3u, sameElem->size()); - verifyQueryTermNode("field.f1", (*sameElem)[0].get()); - verifyQueryTermNode("field.f2", (*sameElem)[1].get()); - verifyQueryTermNode("field.f3", (*sameElem)[2].get()); - - QueryTermList terms; - q.getLeaves(terms); - EXPECT_EQ(3u, terms.size()); - for (QueryTerm * qt : terms) { - qt->resizeFieldId(3); - } - - // field 0 - terms[0]->add(1, 0, 0, 10); - terms[0]->add(2, 0, 1, 20); - terms[0]->add(3, 0, 2, 30); - terms[0]->add(4, 0, 3, 40); - terms[0]->add(5, 0, 4, 50); - terms[0]->add(6, 0, 5, 60); - - terms[1]->add(7, 1, 0, 70); - terms[1]->add(8, 1, 1, 80); - terms[1]->add(9, 1, 2, 90); - terms[1]->add(10, 1, 4, 100); - terms[1]->add(11, 1, 5, 110); - terms[1]->add(12, 1, 6, 120); - - terms[2]->add(13, 2, 0, 130); - terms[2]->add(14, 2, 2, 140); - terms[2]->add(15, 2, 4, 150); - terms[2]->add(16, 2, 5, 160); - terms[2]->add(17, 2, 6, 170); - HitList hits; - sameElem->evaluateHits(hits); - EXPECT_EQ(4u, hits.size()); - EXPECT_EQ(0u, hits[0].wordpos()); - EXPECT_EQ(2u, hits[0].context()); - EXPECT_EQ(0u, hits[0].elemId()); - EXPECT_EQ(130, hits[0].weight()); - - EXPECT_EQ(0u, hits[1].wordpos()); - EXPECT_EQ(2u, hits[1].context()); - EXPECT_EQ(2u, hits[1].elemId()); - EXPECT_EQ(140, hits[1].weight()); - - EXPECT_EQ(0u, hits[2].wordpos()); - EXPECT_EQ(2u, hits[2].context()); - EXPECT_EQ(4u, hits[2].elemId()); - EXPECT_EQ(150, hits[2].weight()); - - EXPECT_EQ(0u, hits[3].wordpos()); - EXPECT_EQ(2u, hits[3].context()); - EXPECT_EQ(5u, hits[3].elemId()); - EXPECT_EQ(160, hits[3].weight()); - EXPECT_TRUE(sameElem->evaluate()); -} - TEST(StreamingQueryTest, test_nearest_neighbor_query_node) { QueryBuilder<SimpleQueryNodeTypes> builder; @@ -917,8 +752,8 @@ TEST(StreamingQueryTest, test_in_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q = *term.get_terms().front(); - q.add(0, 11, 0, 1); - q.add(0, 12, 0, 1); + q.add(11, 0, 1, 0); + q.add(12, 0, 1, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -944,11 +779,11 @@ TEST(StreamingQueryTest, dot_product_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, -13); - q0.add(0, 12, 0, -17); + q0.add(11, 0, -13, 0); + q0.add(12, 0, -17, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 4); - q1.add(0, 12, 0, 9); + q1.add(11, 0, 4, 0); + q1.add(12, 0, 9, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -989,11 +824,11 @@ check_wand_term(double limit, const vespalib::string& label) 
td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 17); - q0.add(0, 12, 0, 13); + q0.add(11, 0, 17, 0); + q0.add(12, 0, 13, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 9); - q1.add(0, 12, 0, 4); + q1.add(11, 0, 9, 0); + q1.add(12, 0, 4, 0); EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -1043,11 +878,11 @@ TEST(StreamingQueryTest, weighted_set_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 10); - q0.add(0, 12, 0, 10); + q0.add(11, 0, 10, 0); + q0.add(12, 0, 10, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 10); - q1.add(0, 12, 0, 10); + q1.add(11, 0, 10, 0); + q1.add(12, 0, 10, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp index bbd2744119a..90452f1d12b 100644 --- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp @@ -23,12 +23,15 @@ class MyOr : public IntermediateBlueprint { private: public: - double calculate_cost() const final { - return OrFlow::cost_of(get_children()); - } double calculate_relative_estimate() const final { return OrFlow::estimate_of(get_children()); } + double calculate_cost() const final { + return OrFlow::cost_of(get_children(), false); + } + double calculate_strict_cost() const final { + return OrFlow::cost_of(get_children(), true); + } HitEstimate combine(const std::vector<HitEstimate> &data) const override { return max(data); } @@ -37,7 +40,7 @@ public: return mixChildrenFields(); } - void sort(Children &children, bool) const override { + void sort(Children &children, bool, bool) const override { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -446,7 +449,7 @@ TEST_F("testChildAndNotCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -486,7 +489,7 @@ TEST_F("testChildAndCollapsing", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -525,7 +528,10 @@ TEST_F("testChildOrCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + // we sort non-strict here since the default costs of 1/est for + // non-strict/strict leaf iterators makes the order of iterators + // under a strict OR irrelevant. 
+ unsorted = Blueprint::optimize_and_sort(std::move(unsorted), false, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -569,7 +575,7 @@ TEST_F("testChildSorting", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -650,12 +656,13 @@ getExpectedBlueprint() " estimate: HitEstimate {\n" " empty: false\n" " estHits: 9\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 2\n" " allow_termwise_eval: false\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " children: std::vector {\n" @@ -671,12 +678,13 @@ getExpectedBlueprint() " estimate: HitEstimate {\n" " empty: false\n" " estHits: 9\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 1\n" " allow_termwise_eval: true\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " }\n" @@ -702,12 +710,13 @@ getExpectedSlimeBlueprint() { " '[type]': 'HitEstimate'," " empty: false," " estHits: 9," - " relative_estimate: 0.5," " cost_tier: 1," " tree_size: 2," " allow_termwise_eval: false" " }," - " cost: 1.0," + " relative_estimate: 0.0," + " cost: 0.0," + " strict_cost: 0.0," " sourceId: 4294967295," " docid_limit: 0," " children: {" @@ -728,12 +737,13 @@ getExpectedSlimeBlueprint() { " '[type]': 'HitEstimate'," " empty: false," " estHits: 9," - " relative_estimate: 0.5," " cost_tier: 1," " tree_size: 1," " allow_termwise_eval: true" " }," - " cost: 1.0," + " relative_estimate: 0.0," + " cost: 0.0," + " strict_cost: 0.0," " sourceId: 4294967295," " docid_limit: 0" " }" @@ -786,9 +796,9 @@ TEST("requireThatDocIdLimitInjectionWorks") } TEST("Control object sizes") { - EXPECT_EQUAL(40u, sizeof(Blueprint::State)); - EXPECT_EQUAL(40u, sizeof(Blueprint)); - EXPECT_EQUAL(80u, sizeof(LeafBlueprint)); + EXPECT_EQUAL(32u, sizeof(Blueprint::State)); + EXPECT_EQUAL(56u, sizeof(Blueprint)); + EXPECT_EQUAL(96u, sizeof(LeafBlueprint)); } TEST_MAIN() { diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 856ac2391f8..2cf523b508b 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -77,7 +77,7 @@ void check_sort_order(IntermediateBlueprint &self, BlueprintVector children, std unordered.push_back(child.get()); } // TODO: sort by cost (requires both setDocIdLimit and optimize to be called) - self.sort(children, false); + self.sort(children, true, false); for (size_t i = 0; i < children.size(); ++i) { EXPECT_EQUAL(children[i].get(), unordered[order[i]]); } @@ -130,8 +130,8 @@ TEST("test AndNot Blueprint") { } template <typename BP> -void optimize(std::unique_ptr<BP> &ref) { - auto optimized = Blueprint::optimize(std::move(ref), true); +void optimize(std::unique_ptr<BP> &ref, bool strict) { + auto optimized = Blueprint::optimize_and_sort(std::move(ref), strict, true); ref.reset(dynamic_cast<BP*>(optimized.get())); ASSERT_TRUE(ref); optimized.release(); @@ -144,7 +144,7 @@ TEST("test And propagates updated histestimate") { bp->addChild(ap(MyLeafSpec(200).create<RememberExecuteInfo>()->setSourceId(2))); 
bp->addChild(ap(MyLeafSpec(2000).create<RememberExecuteInfo>()->setSourceId(2))); bp->setDocIdLimit(5000); - optimize(bp); + optimize(bp, true); bp->fetchPostings(ExecuteInfo::TRUE); EXPECT_EQUAL(3u, bp->childCnt()); for (uint32_t i = 0; i < bp->childCnt(); i++) { @@ -164,7 +164,10 @@ TEST("test Or propagates updated histestimate") { bp->addChild(ap(MyLeafSpec(800).create<RememberExecuteInfo>()->setSourceId(2))); bp->addChild(ap(MyLeafSpec(20).create<RememberExecuteInfo>()->setSourceId(2))); bp->setDocIdLimit(5000); - optimize(bp); + // sort OR as non-strict to get expected order. With strict OR, + // the order would be irrelevant since we use the relative + // estimate as strict_cost for leafs. + optimize(bp, false); bp->fetchPostings(ExecuteInfo::TRUE); EXPECT_EQUAL(4u, bp->childCnt()); for (uint32_t i = 0; i < bp->childCnt(); i++) { @@ -519,16 +522,18 @@ void compare(const Blueprint &bp1, const Blueprint &bp2, bool expect_eq) { auto cmp_hook = [expect_eq](const auto &path, const auto &a, const auto &b) { if (!path.empty() && std::holds_alternative<vespalib::stringref>(path.back())) { vespalib::stringref field = std::get<vespalib::stringref>(path.back()); - if (field == "cost") { + // ignore these fields to enable comparing optimized with unoptimized trees + if (field == "relative_estimate" || field == "cost" || field == "strict_cost") { + auto check_value = [&](double value){ + if ((value > 0.0 && value < 1e-6) || (value > 0.0 && value < 1e-6)) { + fprintf(stderr, " small value at %s: %g\n", path_to_str(path).c_str(), + value); + } + }; + check_value(a.asDouble()); + check_value(b.asDouble()); return true; } - if (field == "relative_estimate") { - double a_val = a.asDouble(); - double b_val = b.asDouble(); - if (a_val != 0.0 && b_val != 0.0 && vespalib::approx_equal(a_val, b_val)) { - return true; - } - } } if (expect_eq) { fprintf(stderr, " mismatch at %s: %s vs %s\n", path_to_str(path).c_str(), @@ -548,13 +553,13 @@ void compare(const Blueprint &bp1, const Blueprint &bp2, bool expect_eq) { } void -optimize_and_compare(Blueprint::UP top, Blueprint::UP expect, bool sort_by_cost = true) { +optimize_and_compare(Blueprint::UP top, Blueprint::UP expect, bool strict = true, bool sort_by_cost = true) { top->setDocIdLimit(1000); expect->setDocIdLimit(1000); TEST_DO(compare(*top, *expect, false)); - top = Blueprint::optimize(std::move(top), sort_by_cost); + top = Blueprint::optimize_and_sort(std::move(top), strict, sort_by_cost); TEST_DO(compare(*top, *expect, true)); - expect = Blueprint::optimize(std::move(expect), sort_by_cost); + expect = Blueprint::optimize_and_sort(std::move(expect), strict, sort_by_cost); TEST_DO(compare(*expect, *top, true)); } @@ -614,7 +619,8 @@ TEST_F("test SourceBlender below OR partial optimization", SourceBlenderTestFixt expect->addChild(addLeafsWithSourceId(std::make_unique<SourceBlenderBlueprint>(f.selector_2), {{10, 1}, {20, 2}})); addLeafs(*expect, {3, 2, 1}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST_F("test OR replaced by source blender after full optimization", SourceBlenderTestFixture) { @@ -626,7 +632,8 @@ TEST_F("test OR replaced by source blender after full optimization", SourceBlend expect->addChild(addLeafsWithSourceId(2, std::make_unique<OrBlueprint>(), {{2000, 2}, {200, 2}, {20, 2}})); expect->addChild(addLeafsWithSourceId(1, std::make_unique<OrBlueprint>(), {{1000, 1}, {100, 1}, {10, 1}})); - 
optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST_F("test SourceBlender below AND_NOT optimization", SourceBlenderTestFixture) { @@ -681,11 +688,11 @@ TEST("test empty root node optimization and safeness") { //------------------------------------------------------------------------- auto expect_up = std::make_unique<EmptyBlueprint>(); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top1), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top2), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top3), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top4), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top5), true)->asString()); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top1), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top2), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top3), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top4), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top5), true, true), true); } TEST("and with one empty child is optimized away") { @@ -693,11 +700,11 @@ TEST("and with one empty child is optimized away") { Blueprint::UP top = ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). addChild(addLeafs(std::make_unique<AndBlueprint>(), {{0, true}, 10, 20}))); - top = Blueprint::optimize(std::move(top), true); + top = Blueprint::optimize_and_sort(std::move(top), true, true); Blueprint::UP expect_up(ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). 
addChild(std::make_unique<EmptyBlueprint>()))); - EXPECT_EQUAL(expect_up->asString(), top->asString()); + compare(*expect_up, *top, true); } struct make { @@ -731,7 +738,7 @@ struct make { return std::move(*this); } make &&leaf(uint32_t estimate) && { - std::unique_ptr<LeafBlueprint> bp(MyLeafSpec(estimate).create()); + std::unique_ptr<MyLeaf> bp(MyLeafSpec(estimate).create()); if (cost_tag != invalid_cost) { bp->set_cost(cost_tag); cost_tag = invalid_cost; @@ -763,7 +770,8 @@ TEST("AND AND collapsing") { TEST("OR OR collapsing") { Blueprint::UP top = make::OR().leafs({1,3,5}).add(make::OR().leafs({2,4})); Blueprint::UP expect = make::OR().leafs({5,4,3,2,1}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST("AND_NOT AND_NOT collapsing") { @@ -841,7 +849,8 @@ TEST("test single child optimization") { TEST("test empty OR child optimization") { Blueprint::UP top = addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 20, {0, true}, 10, {0, true}, 0, 30, {0, true}}); Blueprint::UP expect = addLeafs(std::make_unique<OrBlueprint>(), {30, 20, 10, 0}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST("test empty AND_NOT child optimization") { @@ -868,10 +877,10 @@ TEST("require that replaced blueprints retain source id") { addChild(ap(MyLeafSpec(30).create()->setSourceId(55))))); Blueprint::UP expect2_up(ap(MyLeafSpec(30).create()->setSourceId(42))); //------------------------------------------------------------------------- - top1_up = Blueprint::optimize(std::move(top1_up), true); - top2_up = Blueprint::optimize(std::move(top2_up), true); - EXPECT_EQUAL(expect1_up->asString(), top1_up->asString()); - EXPECT_EQUAL(expect2_up->asString(), top2_up->asString()); + top1_up = Blueprint::optimize_and_sort(std::move(top1_up), true, true); + top2_up = Blueprint::optimize_and_sort(std::move(top2_up), true, true); + compare(*expect1_up, *top1_up, true); + compare(*expect2_up, *top2_up, true); EXPECT_EQUAL(13u, top1_up->getSourceId()); EXPECT_EQUAL(42u, top2_up->getSourceId()); } @@ -1181,45 +1190,25 @@ TEST("require_that_unpack_optimization_is_not_overruled_by_equiv") { } } -TEST("require that children of near are not optimized") { - auto top_up = ap((new NearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - auto expect_up = ap((new NearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - top_up = Blueprint::optimize(std::move(top_up), true); - TEST_DO(compare(*top_up, *expect_up, true)); -} - -TEST("require that children of onear are not optimized") { - auto top_up = ap((new ONearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - auto expect_up = ap((new ONearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). 
- addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - top_up = Blueprint::optimize(std::move(top_up), true); - TEST_DO(compare(*top_up, *expect_up, true)); -} - TEST("require that ANDNOT without children is optimized to empty search") { Blueprint::UP top_up = std::make_unique<AndNotBlueprint>(); auto expect_up = std::make_unique<EmptyBlueprint>(); - top_up = Blueprint::optimize(std::move(top_up), true); - EXPECT_EQUAL(expect_up->asString(), top_up->asString()); + top_up = Blueprint::optimize_and_sort(std::move(top_up), true, true); + compare(*expect_up, *top_up, true); } TEST("require that highest cost tier sorts last for OR") { Blueprint::UP top = addLeafsWithCostTier(std::make_unique<OrBlueprint>(), {{50, 1}, {30, 3}, {20, 2}, {10, 1}}); Blueprint::UP expect = addLeafsWithCostTier(std::make_unique<OrBlueprint>(), {{50, 1}, {10, 1}, {20, 2}, {30, 3}}); - optimize_and_compare(std::move(top), std::move(expect), false); + // cost-based sorting would ignore cost tier + optimize_and_compare(std::move(top), std::move(expect), true, false); } TEST("require that highest cost tier sorts last for AND") { Blueprint::UP top = addLeafsWithCostTier(std::make_unique<AndBlueprint>(), {{10, 1}, {20, 3}, {30, 2}, {50, 1}}); Blueprint::UP expect = addLeafsWithCostTier(std::make_unique<AndBlueprint>(), {{10, 1}, {50, 1}, {30, 2}, {20, 3}}); - optimize_and_compare(std::move(top), std::move(expect), false); + // cost-based sorting would ignore cost tier + optimize_and_compare(std::move(top), std::move(expect), true, false); } template<typename BP> @@ -1292,6 +1281,7 @@ void verify_relative_estimate(make &&mk, double expect) { EXPECT_EQUAL(mk.making->estimate(), 0.0); Blueprint::UP bp = std::move(mk).leafs({200,300,950}); bp->setDocIdLimit(1000); + bp = Blueprint::optimize(std::move(bp)); EXPECT_EQUAL(bp->estimate(), expect); } @@ -1329,15 +1319,17 @@ TEST("relative estimate for WEAKAND") { verify_relative_estimate(make::WEAKAND(50), 0.05); } -void verify_cost(make &&mk, double expect) { - EXPECT_EQUAL(mk.making->cost(), 1.0); +void verify_cost(make &&mk, double expect, double expect_strict) { + EXPECT_EQUAL(mk.making->cost(), 0.0); + EXPECT_EQUAL(mk.making->strict_cost(), 0.0); Blueprint::UP bp = std::move(mk) - .cost(1.1).leaf(200) - .cost(1.2).leaf(300) - .cost(1.3).leaf(500); + .cost(1.1).leaf(200) // strict_cost: 0.2*1.1 + .cost(1.2).leaf(300) // strict_cost: 0.3*1.2 + .cost(1.3).leaf(950); // rel_est: 0.5, strict_cost: 1.3 bp->setDocIdLimit(1000); - bp = Blueprint::optimize(std::move(bp), true); + bp = Blueprint::optimize(std::move(bp)); EXPECT_EQUAL(bp->cost(), expect); + EXPECT_EQUAL(bp->strict_cost(), expect_strict); } double calc_cost(std::vector<std::pair<double,double>> list) { @@ -1351,36 +1343,48 @@ double calc_cost(std::vector<std::pair<double,double>> list) { } TEST("cost for OR") { - verify_cost(make::OR(), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}})); + verify_cost(make::OR(), + calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}}), + calc_cost({{0.2*1.1, 0.8},{0.3*1.2, 0.7},{1.3, 0.5}})); } TEST("cost for AND") { - verify_cost(make::AND(), calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::AND(), + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for RANK") { - verify_cost(make::RANK(), 1.1); // first + verify_cost(make::RANK(), 1.1, 0.2*1.1); // first } TEST("cost for ANDNOT") { - verify_cost(make::ANDNOT(), calc_cost({{1.1, 0.2},{1.3, 0.5},{1.2, 0.7}})); + verify_cost(make::ANDNOT(), + 
calc_cost({{1.1, 0.2},{1.3, 0.5},{1.2, 0.7}}), + calc_cost({{0.2*1.1, 0.2},{1.3, 0.5},{1.2, 0.7}})); } TEST("cost for SB") { InvalidSelector sel; - verify_cost(make::SB(sel), 1.3); // max + verify_cost(make::SB(sel), 1.3, 1.3); // max } TEST("cost for NEAR") { - verify_cost(make::NEAR(1), 3.0 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::NEAR(1), + 0.2*0.3*0.5 * 3 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + 0.2*0.3*0.5 * 3 + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for ONEAR") { - verify_cost(make::ONEAR(1), 3.0 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::ONEAR(1), + 0.2*0.3*0.5 * 3 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + 0.2*0.3*0.5 * 3 + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for WEAKAND") { - verify_cost(make::WEAKAND(1000), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}})); + verify_cost(make::WEAKAND(1000), + calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}}), + calc_cost({{0.2*1.1, 0.8},{0.3*1.2, 0.7},{1.3, 0.5}})); } TEST_MAIN() { TEST_DEBUG("lhs.out", "rhs.out"); TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/queryeval/blueprint/mysearch.h b/searchlib/src/tests/queryeval/blueprint/mysearch.h index db7dd2adae6..6eb27364c2b 100644 --- a/searchlib/src/tests/queryeval/blueprint/mysearch.h +++ b/searchlib/src/tests/queryeval/blueprint/mysearch.h @@ -104,7 +104,8 @@ public: class MyLeaf : public SimpleLeafBlueprint { using TFMDA = search::fef::TermFieldMatchDataArray; - bool _got_global_filter; + bool _got_global_filter = false; + double _cost = 1.0; public: SearchIterator::UP @@ -113,18 +114,15 @@ public: return std::make_unique<MySearch>("leaf", tfmda, strict); } - MyLeaf() - : SimpleLeafBlueprint(), _got_global_filter(false) - {} - MyLeaf(FieldSpecBaseList fields) - : SimpleLeafBlueprint(std::move(fields)), _got_global_filter(false) - {} - + MyLeaf() : SimpleLeafBlueprint() {} + MyLeaf(FieldSpecBaseList fields) : SimpleLeafBlueprint(std::move(fields)) {} + void set_cost(double value) noexcept { _cost = value; } + double calculate_cost() const override { return _cost; } + MyLeaf &estimate(uint32_t hits, bool empty = false) { setEstimate(HitEstimate(hits, empty)); return *this; } - MyLeaf &cost_tier(uint32_t value) { set_cost_tier(value); return *this; @@ -153,7 +151,7 @@ private: public: explicit MyLeafSpec(uint32_t estHits, bool empty = false) - : _fields(), _estimate(estHits, empty), _cost_tier(0), _want_global_filter(false) {} + : _fields(), _estimate(estHits, empty), _cost_tier(0), _want_global_filter(false) {} MyLeafSpec &addField(uint32_t fieldId, uint32_t handle) { _fields.add(FieldSpecBase(fieldId, handle)); diff --git a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp index 1180206279d..4fc8922b9a3 100644 --- a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp +++ b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp @@ -48,7 +48,10 @@ concept ChildCollector = requires(T a, std::unique_ptr<Blueprint> bp) { // inherit Blueprint to capture the default filter factory struct DefaultBlueprint : Blueprint { double calculate_relative_estimate() const override { abort(); } - void optimize(Blueprint* &, OptimizePass, bool) override { abort(); } + double calculate_cost() const override { abort(); } + double calculate_strict_cost() const override { abort(); } + void optimize(Blueprint* &, OptimizePass) override { abort(); } + void sort(bool, bool) override { 
abort(); } const State &getState() const override { abort(); } void fetchPostings(const ExecuteInfo &) override { abort(); } void freeze() override { abort(); } diff --git a/searchlib/src/tests/queryeval/or_speed/CMakeLists.txt b/searchlib/src/tests/queryeval/or_speed/CMakeLists.txt new file mode 100644 index 00000000000..950a3a965be --- /dev/null +++ b/searchlib/src/tests/queryeval/or_speed/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_or_speed_test_app TEST + SOURCES + or_speed_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_or_speed_test_app COMMAND searchlib_or_speed_test_app) diff --git a/searchlib/src/tests/queryeval/or_speed/or_speed_test.cpp b/searchlib/src/tests/queryeval/or_speed/or_speed_test.cpp new file mode 100644 index 00000000000..c27302e818f --- /dev/null +++ b/searchlib/src/tests/queryeval/or_speed/or_speed_test.cpp @@ -0,0 +1,336 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/common/bitvector.h> +#include <vespa/searchlib/common/bitvectoriterator.h> +#include <vespa/searchlib/queryeval/orsearch.h> +#include <vespa/searchlib/queryeval/unpackinfo.h> +#include <vespa/searchlib/queryeval/multibitvectoriterator.h> +#include <vespa/searchlib/fef/termfieldmatchdata.h> +#include <vespa/vespalib/util/stash.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/benchmark_timer.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vector> +#include <random> + +using namespace search; +using namespace vespalib; +using search::queryeval::SearchIterator; +using search::queryeval::OrSearch; +using search::queryeval::UnpackInfo; +using TMD = search::fef::TermFieldMatchData; +using vespalib::make_string_short::fmt; +using Impl = OrSearch::StrictImpl; + +double budget = 5.0; +size_t bench_docs = 10'000'000; +bool bench_mode = false; +constexpr uint32_t default_seed = 5489u; +std::mt19937 gen(default_seed); + +const char *impl_str(Impl impl) { + if (impl == Impl::PLAIN) { return "plain"; } + if (impl == Impl::HEAP) { return " heap"; } + return "unknown"; +} +const char *bool_str(bool bit) { return bit ? "true" : "false"; } +const char *leaf_str(bool array) { return array ? "A" : "B"; } +const char *opt_str(bool optimize) { return optimize ? 
"OPT" : "std"; } + +BitVector::UP make_bitvector(size_t size, size_t num_bits) { + EXPECT_GT(size, num_bits); + auto bv = BitVector::create(size); + size_t bits_left = num_bits; + // bit 0 is never set since it is reserved + // all other bits have equal probability to be set + for (size_t i = 1; i < size; ++i) { + std::uniform_int_distribution<size_t> space(0,size-i-1); + if (space(gen) < bits_left) { + bv->setBit(i); + --bits_left; + } + } + bv->invalidateCachedCount(); + EXPECT_EQ(bv->countTrueBits(), num_bits); + return bv; +} + +// simple strict array-based iterator +// This class has 2 uses: +// 1: better performance for few hits compared to bitvector +// 2: not a bitvector, useful when testing multi-bitvector interactions +struct ArrayIterator : SearchIterator { + uint32_t my_offset = 0; + uint32_t my_limit; + std::vector<uint32_t> my_hits; + TMD &my_match_data; + ArrayIterator(const BitVector &bv, TMD &tmd) + : my_limit(bv.size()), my_match_data(tmd) + { + uint32_t next = bv.getStartIndex(); + for (;;) { + next = bv.getNextTrueBit(next); + if (next >= my_limit) { + break; + } + my_hits.push_back(next++); + } + my_match_data.reset(0); + } + void initRange(uint32_t begin, uint32_t end) final { + SearchIterator::initRange(begin, end); + my_offset = 0; + } + void doSeek(uint32_t docid) final { + while (my_offset < my_hits.size() && my_hits[my_offset] < docid) { + ++my_offset; + } + if (my_offset < my_hits.size()) { + setDocId(my_hits[my_offset]); + } else { + setAtEnd(); + } + } + Trinary is_strict() const final { return Trinary::True; } + void doUnpack(uint32_t docId) final { my_match_data.resetOnlyDocId(docId); } +}; + +struct OrSetup { + uint32_t docid_limit; + bool unpack_all = true; + bool unpack_none = true; + std::vector<std::unique_ptr<TMD>> match_data; + std::vector<BitVector::UP> child_hits; + std::vector<bool> use_array; + OrSetup(uint32_t docid_limit_in) noexcept : docid_limit(docid_limit_in) {} + size_t per_child(double target, size_t child_cnt) { + size_t result = (docid_limit * target) / child_cnt; + return (result >= docid_limit) ? 
(docid_limit - 1) : result; + } + bool should_use_array(size_t hits) { + return (docid_limit / hits) >= 32; + } + OrSetup &add(size_t num_hits, bool use_array_in, bool need_unpack) { + match_data.push_back(std::make_unique<TMD>()); + child_hits.push_back(make_bitvector(docid_limit, num_hits)); + use_array.push_back(use_array_in); + if (need_unpack) { + match_data.back()->setNeedNormalFeatures(true); + match_data.back()->setNeedInterleavedFeatures(true); + unpack_none = false; + } else { + match_data.back()->tagAsNotNeeded(); + unpack_all = false; + } + return *this; + } + SearchIterator::UP make_leaf(size_t i) { + if (use_array[i]) { + return std::make_unique<ArrayIterator>(*child_hits[i], *match_data[i]); + } else { + return BitVectorIterator::create(child_hits[i].get(), *match_data[i], true); + } + } + SearchIterator::UP make_or(Impl impl, bool optimize) { + assert(!child_hits.empty()); + if (child_hits.size() == 1) { + // use child directly if there is only one + return make_leaf(0); + } + std::vector<SearchIterator::UP> children; + for (size_t i = 0; i < child_hits.size(); ++i) { + children.push_back(make_leaf(i)); + } + UnpackInfo unpack; + if (unpack_all) { + unpack.forceAll(); + } else if (!unpack_none) { + for (size_t i = 0; i < match_data.size(); ++i) { + if (!match_data[i]->isNotNeeded()) { + unpack.add(i); + } + } + } + auto result = OrSearch::create(std::move(children), true, unpack, impl); + if (optimize) { + result = queryeval::MultiBitVectorIteratorBase::optimize(std::move(result)); + } + return result; + } + OrSetup &prepare_bm(size_t child_cnt, size_t hits_per_child) { + for (size_t i = 0; i < child_cnt; ++i) { + add(hits_per_child, should_use_array(hits_per_child), false); + } + return *this; + } + std::pair<size_t,double> bm_search_ms(Impl impl, bool optimized) { + auto search_up = make_or(impl, optimized); + SearchIterator &search = *search_up; + size_t hits = 0; + BenchmarkTimer timer(budget); + while (timer.has_budget()) { + timer.before(); + hits = 0; + search.initRange(1, docid_limit); + uint32_t docid = search.seekFirst(1); + while (docid < docid_limit) { + ++hits; + docid = search.seekNext(docid + 1); + // no unpack + } + timer.after(); + } + return std::make_pair(hits, timer.min_time() * 1000.0); + } + void verify_not_match(uint32_t docid) { + for (size_t i = 0; i < match_data.size(); ++i) { + EXPECT_FALSE(child_hits[i]->testBit(docid)); + } + } + void verify_match(uint32_t docid, bool unpacked, bool check_skipped_unpack) { + bool match = false; + for (size_t i = 0; i < match_data.size(); ++i) { + if (child_hits[i]->testBit(docid)) { + match = true; + if (unpacked) { + if (!match_data[i]->isNotNeeded()) { + EXPECT_EQ(match_data[i]->getDocId(), docid) << "unpack was needed"; + } else if (check_skipped_unpack) { + EXPECT_NE(match_data[i]->getDocId(), docid) << "unpack was not needed"; + } + } else { + EXPECT_NE(match_data[i]->getDocId(), docid) << "document was not unpacked"; + } + } else { + EXPECT_NE(match_data[i]->getDocId(), docid) << "document was not a match"; + } + } + EXPECT_TRUE(match); + } + void reset_match_data() { + // this is needed since we re-search the same docid space + // multiple times and may end up finding a result we are not + // unpacking that was unpacked in the last iteration thus + // breaking the "document was not unpacked" test condition. 
+ for (auto &tmd: match_data) { + tmd->resetOnlyDocId(0); + } + } + void verify_seek_unpack(Impl impl, bool check_skipped_unpack, bool optimized) { + auto search_up = make_or(impl, optimized); + SearchIterator &search = *search_up; + for (size_t unpack_nth: {1, 3}) { + for (size_t skip: {1, 31}) { + uint32_t hits = 0; + uint32_t check_at = 1; + search.initRange(1, docid_limit); + uint32_t docid = search.seekFirst(1); + while (docid < docid_limit) { + for (; check_at < docid; ++check_at) { + verify_not_match(check_at); + } + if (++hits % unpack_nth == 0) { + search.unpack(docid); + verify_match(check_at, true, check_skipped_unpack); + } else { + verify_match(check_at, false, check_skipped_unpack); + } + check_at = docid + skip; + docid = search.seekNext(docid + skip); + } + for (; check_at < docid_limit; ++check_at) { + verify_not_match(check_at); + } + reset_match_data(); + } + } + } + ~OrSetup(); +}; +OrSetup::~OrSetup() = default; + +TEST(OrSpeed, array_iterator_seek_unpack) { + OrSetup setup(100); + setup.add(10, true, true); + setup.verify_seek_unpack(Impl::PLAIN, true, false); +} + +TEST(OrSpeed, or_seek_unpack) { + for (bool optimize: {false, true}) { + for (double target: {0.1, 0.5, 1.0, 10.0}) { + for (int unpack: {0,1,2}) { + OrSetup setup(1000); + size_t part = setup.per_child(target, 13); + for (size_t i = 0; i < 13; ++i) { + bool use_array = (i/2)%2 == 0; + bool need_unpack = unpack > 0; + if (unpack == 2 && i % 2 == 0) { + need_unpack = false; + } + setup.add(part, use_array, need_unpack); + } + for (auto impl: {Impl::PLAIN, Impl::HEAP}) { + SCOPED_TRACE(fmt("impl: %s, optimize: %s, part: %zu, unpack: %d", + impl_str(impl), bool_str(optimize), part, unpack)); + setup.verify_seek_unpack(impl, true, optimize); + } + } + } + } +} + +TEST(OrSpeed, bm_array_vs_bitvector) { + if (!bench_mode) { + fprintf(stdout, "[ SKIPPING ] run with 'bench' parameter to activate\n"); + return; + } + for (size_t one_of: {16, 32, 64}) { + double target = 1.0 / one_of; + size_t hits = target * bench_docs; + OrSetup setup(bench_docs); + setup.add(hits, false, false); + for (bool use_array: {false, true}) { + setup.use_array[0] = use_array; + auto result = setup.bm_search_ms(Impl::PLAIN, false); + fprintf(stderr, "LEAF(%s): (one of %4zu) hits: %8zu, time: %10.3f ms, time per hits: %10.3f ns\n", + leaf_str(use_array), one_of, result.first, result.second, (result.second * 1000.0 * 1000.0) / result.first); + } + } +} + +TEST(OrSpeed, bm_strict_or) { + if (!bench_mode) { + fprintf(stdout, "[ SKIPPING ] run with 'bench' parameter to activate\n"); + return; + } + for (double target: {0.001, 0.01, 0.1, 0.5, 1.0, 10.0}) { + for (size_t child_cnt: {2, 3, 4, 5, 10, 100, 250, 500, 1000}) { + for (bool optimize: {false, true}) { + OrSetup setup(bench_docs); + size_t part = setup.per_child(target, child_cnt); + bool use_array = setup.should_use_array(part); + if (part > 0 && (!use_array || !optimize)) { + setup.prepare_bm(child_cnt, part); + for (auto impl: {Impl::PLAIN, Impl::HEAP}) { + auto result = setup.bm_search_ms(impl, optimize); + fprintf(stderr, "OR bench(%s, %s, children: %4zu, hits_per_child: %8zu %s): " + "total_hits: %8zu, time: %10.3f ms, time per hits: %10.3f ns\n", + impl_str(impl), opt_str(optimize), child_cnt, part, leaf_str(use_array), + result.first, result.second, (result.second * 1000.0 * 1000.0) / result.first); + } + } + } + } + } +} + +int main(int argc, char **argv) { + if (argc > 1 && (argv[1] == std::string("bench"))) { + fprintf(stderr, "running in benchmarking mode\n"); + bench_mode = 
true; + ++argv; + --argc; + } + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp index a9f549a0bd9..7a7abb20cdf 100644 --- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp +++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp @@ -626,12 +626,13 @@ TEST(ParallelWeakAndTest, require_that_asString_on_blueprint_works) " estimate: HitEstimate {\n" " empty: false\n" " estHits: 2\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 2\n" " allow_termwise_eval: false\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " _weights: std::vector {\n" @@ -650,12 +651,13 @@ TEST(ParallelWeakAndTest, require_that_asString_on_blueprint_works) " estimate: HitEstimate {\n" " empty: false\n" " estHits: 2\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 1\n" " allow_termwise_eval: true\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " }\n" diff --git a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp index 7c535e5d3d5..c9fcb472b68 100644 --- a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp +++ b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp @@ -46,7 +46,7 @@ std::unique_ptr<SameElementBlueprint> make_blueprint(const std::vector<FakeResul } Blueprint::UP finalize(Blueprint::UP bp, bool strict) { - Blueprint::UP result = Blueprint::optimize(std::move(bp), true); + Blueprint::UP result = Blueprint::optimize_and_sort(std::move(bp), true, true); result->fetchPostings(ExecuteInfo::createForTest(strict)); result->freeze(); return result; diff --git a/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp b/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp index 4e6e565022a..5cc299983f0 100644 --- a/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp +++ b/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp @@ -5,6 +5,7 @@ #include <vespa/vespalib/stllike/string.h> #include <vespa/vespalib/text/lowercase.h> #include <vespa/vespalib/text/utf8.h> +#include <algorithm> using search::FoldedStringCompare; using vespalib::LowerCase; diff --git a/searchlib/src/vespa/searchcommon/attribute/basictype.cpp b/searchlib/src/vespa/searchcommon/attribute/basictype.cpp index c63d07ca130..0312154690c 100644 --- a/searchlib/src/vespa/searchcommon/attribute/basictype.cpp +++ b/searchlib/src/vespa/searchcommon/attribute/basictype.cpp @@ -35,4 +35,13 @@ BasicType::asType(const vespalib::string &t) return NONE; } +bool +BasicType::is_integer_type() const noexcept +{ + return (_type == INT8) || + (_type == INT16) || + (_type == INT32) || + (_type == INT64); +} + } diff --git a/searchlib/src/vespa/searchcommon/attribute/basictype.h b/searchlib/src/vespa/searchcommon/attribute/basictype.h index 35200b3f62d..4bdee5ecfd9 100644 --- a/searchlib/src/vespa/searchcommon/attribute/basictype.h +++ b/searchlib/src/vespa/searchcommon/attribute/basictype.h @@ -36,6 +36,7 @@ class BasicType Type type() const noexcept { return _type; } const char * asString() const noexcept { return asString(_type); } 
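Note on the new BasicType::is_integer_type() helper: it is defined in basictype.cpp above and declared in the header just below, and it centralizes the integer check that singlenumericpostattribute.hpp and multinumericpostattribute.hpp reuse further down in this change. A minimal sketch of the helper in isolation; the include path, the search::attribute namespace qualification, and the tiny main() are illustrative assumptions, not part of the patch:

    #include <vespa/searchcommon/attribute/basictype.h>
    #include <cassert>
    #include <cstdint>

    using search::attribute::BasicType;

    int main() {
        // fromType() maps C++ value types to BasicType values (see the header context above).
        assert(BasicType::fromType(int64_t{7}).is_integer_type());  // INT64 -> integer
        assert(BasicType::fromType(int8_t{1}).is_integer_type());   // INT8  -> integer
        assert(!BasicType::fromType(true).is_integer_type());       // BOOL  -> not one of INT8/16/32/64
        return 0;
    }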
size_t fixedSize() const noexcept { return fixedSize(_type); } + bool is_integer_type() const noexcept; static BasicType fromType(bool) noexcept { return BOOL; } static BasicType fromType(int8_t) noexcept { return INT8; } static BasicType fromType(int16_t) noexcept { return INT16; } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 5d689f5bd81..d5c67664a5e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -389,12 +389,6 @@ private: //----------------------------------------------------------------------------- - - - - -//----------------------------------------------------------------------------- - class DirectWandBlueprint : public queryeval::ComplexLeafBlueprint { private: @@ -494,69 +488,6 @@ AttributeFieldBlueprint::getRange(vespalib::string &from, vespalib::string &to) return false; } -//----------------------------------------------------------------------------- - -class DirectAttributeBlueprint : public queryeval::SimpleLeafBlueprint -{ -private: - const IAttributeVector &_iattr; - const IDocidWithWeightPostingStore &_attr; - vespalib::datastore::EntryRef _dictionary_snapshot; - IDirectPostingStore::LookupResult _dict_entry; - -public: - DirectAttributeBlueprint(const FieldSpec &field, const IAttributeVector &iattr, - const IDocidWithWeightPostingStore &attr, - const IDirectPostingStore::LookupKey & key) - : SimpleLeafBlueprint(field), - _iattr(iattr), - _attr(attr), - _dictionary_snapshot(_attr.get_dictionary_snapshot()), - _dict_entry(_attr.lookup(key, _dictionary_snapshot)) - { - setEstimate(HitEstimate(_dict_entry.posting_size, (_dict_entry.posting_size == 0))); - } - - SearchIterator::UP createLeafSearch(const TermFieldMatchDataArray &tfmda, bool strict) const override { - assert(tfmda.size() == 1); - if (_dict_entry.posting_size == 0) { - return std::make_unique<queryeval::EmptySearch>(); - } - if (tfmda[0]->isNotNeeded()) { - auto bitvector_iterator = _attr.make_bitvector_iterator(_dict_entry.posting_idx, get_docid_limit(), *tfmda[0], strict); - if (bitvector_iterator) { - return bitvector_iterator; - } - } - if (_attr.has_btree_iterator(_dict_entry.posting_idx)) { - return std::make_unique<queryeval::DocidWithWeightSearchIterator>(*tfmda[0], _attr, _dict_entry); - } else { - return _attr.make_bitvector_iterator(_dict_entry.posting_idx, get_docid_limit(), *tfmda[0], strict); - } - } - - SearchIteratorUP createFilterSearch(bool strict, FilterConstraint constraint) const override { - (void) constraint; // We provide an iterator with exact results, so no need to take constraint into consideration. 
- auto wrapper = std::make_unique<FilterWrapper>(getState().numFields()); - wrapper->wrap(createLeafSearch(wrapper->tfmda(), strict)); - return wrapper; - } - - void visitMembers(vespalib::ObjectVisitor &visitor) const override { - LeafBlueprint::visitMembers(visitor); - visit_attribute(visitor, _iattr); - } - std::unique_ptr<queryeval::MatchingElementsSearch> create_matching_elements_search(const MatchingElementsFields &fields) const override { - if (fields.has_field(_iattr.getName())) { - return queryeval::MatchingElementsSearch::create(_iattr, _dictionary_snapshot, vespalib::ConstArrayRef<IDirectPostingStore::LookupResult>(&_dict_entry, 1)); - } else { - return {}; - } - } -}; - -//----------------------------------------------------------------------------- - bool check_valid_diversity_attr(const IAttributeVector *attr) { if ((attr == nullptr) || attr->hasMultiValue()) { return false; @@ -579,8 +510,7 @@ private: const IDocidWithWeightPostingStore *_dwwps; vespalib::string _scratchPad; - bool use_docid_with_weight_posting_store() const { - // TODO: Relax requirement on always having weight iterator for query operators where that makes sense. + bool has_always_btree_iterators_with_docid_and_weight() const { return (_dwwps != nullptr) && (_dwwps->has_always_btree_iterator()); } @@ -598,15 +528,6 @@ public: ~CreateBlueprintVisitor() override; template <class TermNode> - void visitSimpleTerm(TermNode &n) { - if (use_docid_with_weight_posting_store() && !_field.isFilter() && n.isRanked() && !Term::isPossibleRangeTerm(n.getTerm())) { - NodeAsKey key(n, _scratchPad); - setResult(std::make_unique<DirectAttributeBlueprint>(_field, _attr, *_dwwps, key)); - } else { - visitTerm(n); - } - } - template <class TermNode> void visitTerm(TermNode &n) { SearchContextParams scParams = createContextParams(_field.isFilter()); scParams.fuzzy_matching_algorithm(getRequestContext().get_attribute_blueprint_params().fuzzy_matching_algorithm); @@ -628,7 +549,7 @@ public: } } - void visit(NumberTerm & n) override { visitSimpleTerm(n); } + void visit(NumberTerm & n) override { visitTerm(n); } void visit(LocationTerm &n) override { visitLocation(n); } void visit(PrefixTerm & n) override { visitTerm(n); } @@ -652,7 +573,7 @@ public: } } - void visit(StringTerm & n) override { visitSimpleTerm(n); } + void visit(StringTerm & n) override { visitTerm(n); } void visit(SubstringTerm & n) override { query::SimpleRegExpTerm re(vespalib::RegexpUtil::make_from_substring(n.getTerm()), n.getView(), n.getId(), n.getWeight()); @@ -680,32 +601,41 @@ public: return std::make_unique<QueryTermUCS4>(term, QueryTermSimple::Type::WORD); } - void visit(query::WeightedSetTerm &n) override { - bool isSingleValue = !_attr.hasMultiValue(); - bool isString = (_attr.isStringType() && _attr.hasEnum()); - bool isInteger = _attr.isIntegerType(); - if (isSingleValue && (isString || isInteger)) { - auto ws = std::make_unique<AttributeWeightedSetBlueprint>(_field, _attr); - SearchContextParams scParams = createContextParams(); - for (size_t i = 0; i < n.getNumTerms(); ++i) { - auto term = n.getAsString(i); - ws->addToken(_attr.createSearchContext(extractTerm(term.first, isInteger), scParams), term.second.percent()); - } - setResult(std::move(ws)); + template <typename TermType, typename SearchType> + void visit_wset_or_in_term(TermType& n) { + if (_dps != nullptr) { + auto* bp = new attribute::DirectMultiTermBlueprint<IDocidPostingStore, SearchType> + (_field, _attr, *_dps, n.getNumTerms()); + createDirectMultiTerm(bp, n); + } else if (_dwwps != 
nullptr) { + auto* bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, SearchType> + (_field, _attr, *_dwwps, n.getNumTerms()); + createDirectMultiTerm(bp, n); } else { - if (use_docid_with_weight_posting_store()) { - auto *bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, queryeval::WeightedSetTermSearch> - (_field, _attr, *_dwwps, n.getNumTerms()); - createDirectMultiTerm(bp, n); + bool isSingleValue = !_attr.hasMultiValue(); + bool isString = (_attr.isStringType() && _attr.hasEnum()); + bool isInteger = _attr.isIntegerType(); + if (isSingleValue && (isString || isInteger)) { + auto ws = std::make_unique<AttributeWeightedSetBlueprint>(_field, _attr); + SearchContextParams scParams = createContextParams(); + for (size_t i = 0; i < n.getNumTerms(); ++i) { + auto term = n.getAsString(i); + ws->addToken(_attr.createSearchContext(extractTerm(term.first, isInteger), scParams), term.second.percent()); + } + setResult(std::move(ws)); } else { - auto *bp = new WeightedSetTermBlueprint(_field); + auto* bp = new WeightedSetTermBlueprint(_field); createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); } } } + void visit(query::WeightedSetTerm &n) override { + visit_wset_or_in_term<query::WeightedSetTerm, queryeval::WeightedSetTermSearch>(n); + } + void visit(query::DotProduct &n) override { - if (use_docid_with_weight_posting_store()) { + if (has_always_btree_iterators_with_docid_and_weight()) { auto *bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, queryeval::DotProductSearch> (_field, _attr, *_dwwps, n.getNumTerms()); createDirectMultiTerm(bp, n); @@ -716,7 +646,7 @@ public: } void visit(query::WandTerm &n) override { - if (use_docid_with_weight_posting_store()) { + if (has_always_btree_iterators_with_docid_and_weight()) { auto *bp = new DirectWandBlueprint(_field, *_dwwps, n.getTargetNumHits(), n.getScoreThreshold(), n.getThresholdBoostFactor(), n.getNumTerms()); @@ -731,18 +661,7 @@ public: } void visit(query::InTerm &n) override { - if (_dps != nullptr) { - auto* bp = new attribute::DirectMultiTermBlueprint<IDocidPostingStore, attribute::InTermSearch> - (_field, _attr, *_dps, n.getNumTerms()); - createDirectMultiTerm(bp, n); - } else if (_dwwps != nullptr) { - auto* bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, attribute::InTermSearch> - (_field, _attr, *_dwwps, n.getNumTerms()); - createDirectMultiTerm(bp, n); - } else { - auto* bp = new WeightedSetTermBlueprint(_field); - createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); - } + visit_wset_or_in_term<query::InTerm, attribute::InTermSearch>(n); } void fail_nearest_neighbor_term(query::NearestNeighborTerm&n, const vespalib::string& error_msg) { diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 01148c11c9c..d0353ab8947 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -6,12 +6,12 @@ #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/matchdatalayout.h> #include <vespa/searchlib/query/query_term_ucs4.h> +#include <vespa/searchlib/queryeval/filter_wrapper.h> +#include <vespa/searchlib/queryeval/orsearch.h> #include <vespa/searchlib/queryeval/weighted_set_term_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/vespalib/util/stringfmt.h> 
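The visit_wset_or_in_term() helper introduced in attribute_blueprint_factory.cpp above gives WeightedSetTerm and InTerm one shared blueprint-selection order: a plain docid posting store first, then a docid-with-weight posting store, then the pre-existing weighted-set fallbacks. A condensed sketch of that decision order follows; the enum, struct and function names are stand-ins for illustration only, not types from the Vespa code base:

    // Mirrors the if/else chain in visit_wset_or_in_term(); names are hypothetical.
    enum class MultiTermPath {
        DirectDocid,             // DirectMultiTermBlueprint<IDocidPostingStore, SearchType>
        DirectDocidWithWeight,   // DirectMultiTermBlueprint<IDocidWithWeightPostingStore, SearchType>
        AttributeWeightedSet,    // AttributeWeightedSetBlueprint (single-value string/integer attribute)
        ShallowWeightedSet       // WeightedSetTermBlueprint
    };

    struct AttrInfo {
        bool has_docid_store;              // _dps != nullptr
        bool has_docid_with_weight_store;  // _dwwps != nullptr
        bool multi_value;
        bool string_with_enum;
        bool integer;
    };

    MultiTermPath choose_path(const AttrInfo& a) {
        if (a.has_docid_store) return MultiTermPath::DirectDocid;
        if (a.has_docid_with_weight_store) return MultiTermPath::DirectDocidWithWeight;
        bool single_value = !a.multi_value;
        if (single_value && (a.string_with_enum || a.integer)) return MultiTermPath::AttributeWeightedSet;
        return MultiTermPath::ShallowWeightedSet;
    }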
#include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/searchlib/queryeval/filter_wrapper.h> -#include <vespa/searchlib/queryeval/orsearch.h> +#include <vespa/vespalib/util/stringfmt.h> namespace search { @@ -38,6 +38,7 @@ class StringEnumWrapper : public AttrWrapper { public: using TokenT = uint32_t; + static constexpr bool unpack_weights = true; explicit StringEnumWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} auto mapToken(const ISearchContext &context) const { @@ -52,6 +53,7 @@ class IntegerWrapper : public AttrWrapper { public: using TokenT = uint64_t; + static constexpr bool unpack_weights = true; explicit IntegerWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp index 9d6e2e9815d..5f3ab9cd3d8 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp +++ b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp @@ -54,6 +54,11 @@ DfaFuzzyMatcher::DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uin _successor = _prefix; } +DfaFuzzyMatcher::DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uint32_t prefix_size, bool cased) + : DfaFuzzyMatcher(target, max_edits, prefix_size, cased, LevenshteinDfa::DfaType::Table) +{ +} + DfaFuzzyMatcher::~DfaFuzzyMatcher() = default; const char* @@ -69,10 +74,10 @@ DfaFuzzyMatcher::skip_prefix(const char* word) const } bool -DfaFuzzyMatcher::is_match(const char* word) const +DfaFuzzyMatcher::is_match(std::string_view word) const { if (_prefix_size > 0) { - Utf8ReaderForZTS reader(word); + Utf8Reader reader(word.data(), word.size()); size_t pos = 0; for (; pos < _prefix.size() && reader.hasMore(); ++pos) { uint32_t code_point = reader.getChar(); @@ -89,7 +94,7 @@ DfaFuzzyMatcher::is_match(const char* word) const if (pos != _prefix_size) { return false; } - word = reader.get_current_ptr(); + word = word.substr(reader.getPos()); } auto match = _dfa.match(word); return match.matches(); diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h index 51457129637..653af602c0d 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h +++ b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h @@ -27,9 +27,14 @@ private: const char* skip_prefix(const char* word) const; public: DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uint32_t prefix_size, bool cased, vespalib::fuzzy::LevenshteinDfa::DfaType dfa_type); + DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uint32_t prefix_size, bool cased); // Defaults to table-based DFA ~DfaFuzzyMatcher(); - bool is_match(const char *word) const; + [[nodiscard]] static constexpr bool supports_max_edits(uint8_t edits) noexcept { + return (edits == 1 || edits == 2); + } + + [[nodiscard]] bool is_match(std::string_view word) const; /* * If prefix size is nonzero then this variant of is_match() @@ -40,7 +45,7 @@ public: * functionality in the dictionary. 
*/ template <typename DictionaryConstIteratorType> - bool is_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store); + [[nodiscard]] bool is_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store); }; } diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h index 066b70481dc..485427391ad 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h @@ -35,6 +35,8 @@ private: using IteratorType = typename PostingStoreType::IteratorType; using IteratorWeights = std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>>; + bool use_hash_filter(bool strict) const; + IteratorWeights create_iterators(std::vector<IteratorType>& btree_iterators, std::vector<std::unique_ptr<queryeval::SearchIterator>>& bitvectors, bool use_bitvector_when_available, @@ -44,7 +46,7 @@ private: std::vector<std::unique_ptr<queryeval::SearchIterator>>&& bitvectors, bool strict) const; - template <bool filter_search, bool need_match_data> + template <bool filter_search> std::unique_ptr<queryeval::SearchIterator> create_search_helper(const fef::TermFieldMatchDataArray& tfmda, bool strict) const; diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp index f195e97fee0..0a3b24142a5 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp @@ -8,6 +8,7 @@ #include <vespa/searchlib/queryeval/emptysearch.h> #include <vespa/searchlib/queryeval/filter_wrapper.h> #include <vespa/searchlib/queryeval/orsearch.h> +#include <cmath> #include <memory> #include <type_traits> @@ -38,6 +39,43 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::DirectMultiTermBlueprint template <typename PostingStoreType, typename SearchType> DirectMultiTermBlueprint<PostingStoreType, SearchType>::~DirectMultiTermBlueprint() = default; + +template <typename PostingStoreType, typename SearchType> +bool +DirectMultiTermBlueprint<PostingStoreType, SearchType>::use_hash_filter(bool strict) const +{ + if (strict || _iattr.hasMultiValue()) { + return false; + } + // The following very simplified formula was created after analysing performance of the IN operator + // on a 10M document corpus using a machine with an Intel Xeon 2.5 GHz CPU with 48 cores and 256 Gb of memory: + // https://github.com/vespa-engine/system-test/tree/master/tests/performance/in_operator + // + // The following 25 test cases were used to calculate the cost of using btree iterators (strict): + // op_hits_ratios = [5, 10, 50, 100, 200] * tokens_in_op = [1, 5, 10, 100, 1000] + // For each case we calculate the latency difference against the case with tokens_in_op=1 and the same op_hits_ratio. + // This indicates the extra time used to produce the same number of hits when having multiple tokens in the operator. + // The latency diff is divided with the number of hits produced and convert to nanoseconds: + // 10M * (op_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we can approximate the cost per document (in nanoseconds) as: + // 8.0 (ns) * log2(tokens_in_op). + // NOTE: This is very simplified. 
Ideally we should also take into consideration the hit estimate of this blueprint,
+    // as the cost per document will be lower when producing few hits.
+    //
+    // In addition, the following 28 test cases were used to calculate the cost of using the hash filter (non-strict).
+    //   filter_hits_ratios = [1, 5, 10, 50, 100, 150, 200] x op_hits_ratios = [200] x tokens_in_op = [5, 10, 100, 1000]
+    // The code was altered to always use the hash filter for non-strict iterators.
+    // For each case we calculate the latency difference against a case from above with tokens_in_op=1 that produces a similar number of hits.
+    // This indicates the extra time used to produce the same number of hits when using the hash filter.
+    // The latency diff is divided by the number of hits the test filter produces and converted to nanoseconds:
+    //   10M * (filter_hits_ratio / 1000) * 1000 * 1000
+    // Based on the numbers we calculate the average cost per document (in nanoseconds) as 26.0 ns. With these costs the hash filter becomes the cheaper option once the operator has roughly 10 or more terms (26.0 < 8.0 * log2(n) for n >= 10).
+
+    float hash_filter_cost_per_doc_ns = 26.0;
+    float btree_iterator_cost_per_doc_ns = 8.0 * std::log2(_terms.size());
+    return hash_filter_cost_per_doc_ns < btree_iterator_cost_per_doc_ns;
+}
+
 template <typename PostingStoreType, typename SearchType>
 typename DirectMultiTermBlueprint<PostingStoreType, SearchType>::IteratorWeights
 DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_iterators(std::vector<IteratorType>& btree_iterators,
@@ -88,7 +126,7 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::combine_iterators(std::u
 }
 
 template <typename PostingStoreType, typename SearchType>
-template <bool filter_search, bool need_match_data>
+template <bool filter_search>
 std::unique_ptr<queryeval::SearchIterator>
 DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_search_helper(const fef::TermFieldMatchDataArray& tfmda,
                                                                              bool strict) const
@@ -96,34 +134,32 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_search_helper(con
     if (_terms.empty()) {
         return std::make_unique<queryeval::EmptySearch>();
     }
+    auto& tfmd = *tfmda[0];
+    bool field_is_filter = getState().fields()[0].isFilter();
+    if constexpr (SearchType::supports_hash_filter) {
+        if (use_hash_filter(strict)) {
+            return SearchType::create_hash_filter(tfmd, (filter_search || field_is_filter),
+                                                  _weights, _terms,
+                                                  _iattr, _attr, _dictionary_snapshot);
+        }
+    }
     std::vector<IteratorType> btree_iterators;
     std::vector<queryeval::SearchIterator::UP> bitvectors;
     const size_t num_children = _terms.size();
     btree_iterators.reserve(num_children);
-    auto& tfmd = *tfmda[0];
     bool use_bit_vector_when_available = filter_search || !_attr.has_always_btree_iterator();
     auto weights = create_iterators(btree_iterators, bitvectors, use_bit_vector_when_available, tfmd, strict);
-    if constexpr (filter_search || (!need_match_data && !SearchType::require_btree_iterators)) {
-        auto filter = !btree_iterators.empty() ?
-                      (need_match_data ?
-                       attribute::MultiTermOrFilterSearch::create(std::move(btree_iterators), tfmd) :
-                       attribute::MultiTermOrFilterSearch::create(std::move(btree_iterators))) :
-                      std::unique_ptr<SearchIterator>();
-        return combine_iterators(std::move(filter), std::move(bitvectors), strict);
-    }
-    bool field_is_filter = getState().fields()[0].isFilter();
-    if constexpr (!filter_search && !SearchType::require_btree_iterators) {
+    if constexpr (!SearchType::require_btree_iterators) {
         auto multi_term = !btree_iterators.empty() ?
- SearchType::create(tfmd, field_is_filter, std::move(weights), std::move(btree_iterators)) + SearchType::create(tfmd, (filter_search || field_is_filter), std::move(weights), std::move(btree_iterators)) : std::unique_ptr<SearchIterator>(); return combine_iterators(std::move(multi_term), std::move(bitvectors), strict); - } else if constexpr (SearchType::require_btree_iterators) { + } else { // In this case we should only have btree iterators. assert(btree_iterators.size() == _terms.size()); assert(weights.index() == 0); return SearchType::create(tfmd, field_is_filter, std::get<0>(weights).get(), std::move(btree_iterators)); } - return std::make_unique<queryeval::EmptySearch>(); } template <typename PostingStoreType, typename SearchType> @@ -132,12 +168,7 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::createLeafSearch(const f { assert(tfmda.size() == 1); assert(getState().numFields() == 1); - bool need_match_data = !tfmda[0]->isNotNeeded(); - if (need_match_data) { - return create_search_helper<SearchType::filter_search, true>(tfmda, strict); - } else { - return create_search_helper<SearchType::filter_search, false>(tfmda, strict); - } + return create_search_helper<SearchType::filter_search>(tfmda, strict); } template <typename PostingStoreType, typename SearchType> @@ -146,7 +177,7 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::createFilterSearch(bool { assert(getState().numFields() == 1); auto wrapper = std::make_unique<FilterWrapper>(getState().numFields()); - wrapper->wrap(create_search_helper<true, false>(wrapper->tfmda(), strict)); + wrapper->wrap(create_search_helper<true>(wrapper->tfmda(), strict)); return wrapper; } diff --git a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp index 868c0013dd5..60129a9e577 100644 --- a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp @@ -4,6 +4,7 @@ #include "i_enum_store.h" #include "i_enum_store_dictionary.h" #include <vespa/vespalib/util/array.hpp> +#include <algorithm> namespace search::enumstore { diff --git a/searchlib/src/vespa/searchlib/attribute/in_term_search.h b/searchlib/src/vespa/searchlib/attribute/in_term_search.h index 36776499e51..f9a48af2aba 100644 --- a/searchlib/src/vespa/searchlib/attribute/in_term_search.h +++ b/searchlib/src/vespa/searchlib/attribute/in_term_search.h @@ -2,14 +2,20 @@ #pragma once +#include <vespa/searchlib/queryeval/weighted_set_term_search.h> + namespace search::attribute { /** - * Class used as template argument in DirectMultiTermBlueprint to configure it for the InTerm query operator. + * Search iterator for an InTerm, sharing the implementation with WeightedSetTerm. + * + * The only difference is that an InTerm never requires unpacking of weights. */ -struct InTermSearch { +class InTermSearch : public queryeval::WeightedSetTermSearch { +public: + // Whether this iterator is considered a filter, independent of attribute vector settings (ref. rank: filter). + // Setting this to true ensures that weights are never unpacked. 
static constexpr bool filter_search = true; - static constexpr bool require_btree_iterators = false; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h index 9c3ea258fdc..43500366f21 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h @@ -23,6 +23,7 @@ class MultiTermHashFilter final : public queryeval::SearchIterator public: using Key = typename WrapperType::TokenT; using TokenMap = vespalib::hash_map<Key, int32_t, vespalib::hash<Key>, std::equal_to<Key>, vespalib::hashtable_base::and_modulator>; + static constexpr bool unpack_weights = WrapperType::unpack_weights; private: fef::TermFieldMatchData& _tfmd; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp index 96d5b3ac1f3..e67bda147d7 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp @@ -42,10 +42,14 @@ template <typename WrapperType> void MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId) { - _tfmd.reset(docId); - fef::TermFieldMatchDataPosition pos; - pos.setElementWeight(_weight); - _tfmd.appendPosition(pos); + if constexpr (unpack_weights) { + _tfmd.reset(docId); + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(_weight); + _tfmd.appendPosition(pos); + } else { + _tfmd.resetOnlyDocId(docId); + } } } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_or_filter_search.h b/searchlib/src/vespa/searchlib/attribute/multi_term_or_filter_search.h index 1e8227c3007..f9357081d74 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_or_filter_search.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_or_filter_search.h @@ -9,8 +9,7 @@ namespace search::attribute { /** * Filter iterator on top of low-level posting list iterators or regular search iterators with OR semantics. * - * Used during calculation of global filter for InTerm, WeightedSetTerm, DotProduct and WandTerm, - * or when ranking is not needed for InTerm and WeightedSetTerm. + * Used during calculation of global filter for DotProduct and WandTerm. 
*/ class MultiTermOrFilterSearch : public queryeval::SearchIterator { diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp index ea1058d88fb..9328857c919 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp @@ -89,7 +89,7 @@ template <typename B, typename M> const IDocidWithWeightPostingStore* MultiValueNumericPostingAttribute<B, M>::as_docid_with_weight_posting_store() const { - if (this->hasWeightedSetType() && (this->getBasicType() == AttributeVector::BasicType::INT64)) { + if (this->getConfig().basicType().is_integer_type()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp index b6e9b69a81d..371b8d920f9 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp @@ -108,8 +108,7 @@ template <typename B, typename T> const IDocidWithWeightPostingStore* MultiValueStringPostingAttributeT<B, T>::as_docid_with_weight_posting_store() const { - // TODO: Add support for handling bit vectors too, and lift restriction on isFilter. - if (this->hasWeightedSetType() && this->isStringType()) { + if (this->isStringType()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp index c57742ca4b6..c74e25fbf5a 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp @@ -150,22 +150,11 @@ SingleValueNumericPostingAttribute<B>::getSearch(QueryTermSimple::UP qTerm, return std::make_unique<SC>(std::move(base_sc), params, *this); } -namespace { - -bool is_integer_type(attribute::BasicType type) { - return (type == attribute::BasicType::INT8) || - (type == attribute::BasicType::INT16) || - (type == attribute::BasicType::INT32) || - (type == attribute::BasicType::INT64); -} - -} - template <typename B> const IDocidPostingStore* SingleValueNumericPostingAttribute<B>::as_docid_posting_store() const { - if (is_integer_type(this->getBasicType())) { + if (this->getConfig().basicType().is_integer_type()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp index 75885aa0402..f1a643dc376 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp +++ b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp @@ -80,7 +80,8 @@ StringSearchHelper::isMatch(const char *src) const noexcept { return getRegex().valid() && getRegex().partial_match(std::string_view(src)); } if (__builtin_expect(isFuzzy(), false)) { - return _dfa_fuzzy_matcher ? _dfa_fuzzy_matcher->is_match(src) : getFuzzyMatcher().isMatch(src); + return _dfa_fuzzy_matcher ? 
_dfa_fuzzy_matcher->is_match(std::string_view(src)) + : getFuzzyMatcher().isMatch(std::string_view(src)); } if (__builtin_expect(isCased(), false)) { int res = strncmp(_term, src, _termLen); diff --git a/searchlib/src/vespa/searchlib/common/geo_gcd.cpp b/searchlib/src/vespa/searchlib/common/geo_gcd.cpp index 194adf015b6..a9c2cda664a 100644 --- a/searchlib/src/vespa/searchlib/common/geo_gcd.cpp +++ b/searchlib/src/vespa/searchlib/common/geo_gcd.cpp @@ -12,6 +12,11 @@ static constexpr double earth_mean_radius = 6371.0088; static constexpr double degrees_to_radians = M_PI / 180.0; +static constexpr double internal_from_km = (1.0e6 * 180.0) / (M_PI * earth_mean_radius); + +double greatCircleDistance(double theta_A, double phi_A, + double theta_B, double phi_B) __attribute__((noinline)); + // with input in radians double greatCircleDistance(double theta_A, double phi_A, double theta_B, double phi_B) @@ -47,4 +52,8 @@ double GeoGcd::km_great_circle_distance(double lat, double lng) const { return greatCircleDistance(theta_A, phi_A, theta_B, phi_B); } +double GeoGcd::km_to_internal(double km) { + return km * internal_from_km; +} + } // namespace search::common diff --git a/searchlib/src/vespa/searchlib/common/geo_gcd.h b/searchlib/src/vespa/searchlib/common/geo_gcd.h index f829625ce5d..8dd003727cc 100644 --- a/searchlib/src/vespa/searchlib/common/geo_gcd.h +++ b/searchlib/src/vespa/searchlib/common/geo_gcd.h @@ -23,6 +23,7 @@ struct GeoGcd } double km_great_circle_distance(double lat, double lng) const; + static double km_to_internal(double km); private: double _latitude_radians; double _longitude_radians; diff --git a/searchlib/src/vespa/searchlib/features/distancefeature.cpp b/searchlib/src/vespa/searchlib/features/distancefeature.cpp index fd84fdb9ccb..15362b6a224 100644 --- a/searchlib/src/vespa/searchlib/features/distancefeature.cpp +++ b/searchlib/src/vespa/searchlib/features/distancefeature.cpp @@ -6,6 +6,7 @@ #include <vespa/document/datatype/positiondatatype.h> #include <vespa/searchcommon/common/schema.h> #include <vespa/searchlib/common/geo_location_spec.h> +#include <vespa/searchlib/common/geo_gcd.h> #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/vespalib/geo/zcurve.h> @@ -69,79 +70,103 @@ ConvertRawscoreToDistance::execute(uint32_t docId) outputs().set_number(0, min_distance); } +const feature_t DistanceExecutor::DEFAULT_DISTANCE(6400000000.0); -feature_t -DistanceExecutor::calculateDistance(uint32_t docId) -{ - _best_index = -1.0; - _best_x = -180.0 * 1.0e6; - _best_y = 90.0 * 1.0e6; - if ((! _locations.empty()) && (_pos != nullptr)) { - LOG(debug, "calculate 2D Z-distance from %zu locations", _locations.size()); - return calculate2DZDistance(docId); - } - return DEFAULT_DISTANCE; -} +/** + * Implements the executor for the great circle distance feature. + */ +class GeoGCDExecutor : public fef::FeatureExecutor { +private: + std::vector<search::common::GeoGcd> _locations; + const attribute::IAttributeVector * _pos; + attribute::IntegerContent _intBuf; + feature_t _best_index; + feature_t _best_lat; + feature_t _best_lng; + feature_t calculateGeoGCD(uint32_t docId); +public: + /** + * Constructs an executor for the GeoGCD feature. + * + * @param locations location objects associated with the query environment. + * @param pos the attribute to use for positions (expects zcurve encoding). 
+ */ + GeoGCDExecutor(GeoLocationSpecPtrs locations, const attribute::IAttributeVector * pos); + void execute(uint32_t docId) override; +}; -feature_t -DistanceExecutor::calculate2DZDistance(uint32_t docId) -{ + +feature_t GeoGCDExecutor::calculateGeoGCD(uint32_t docId) { + feature_t dist = std::numeric_limits<feature_t>::max(); + _best_index = -1; + _best_lat = 90.0; + _best_lng = -180.0; + if (_locations.empty()) { + return dist; + } _intBuf.fill(*_pos, docId); uint32_t numValues = _intBuf.size(); - uint64_t sqabsdist = std::numeric_limits<uint64_t>::max(); int32_t docx = 0; int32_t docy = 0; for (auto loc : _locations) { - assert(loc); - assert(loc->location.valid()); for (uint32_t i = 0; i < numValues; ++i) { vespalib::geo::ZCurve::decode(_intBuf[i], &docx, &docy); - uint64_t sqdist = loc->location.sq_distance_to({docx, docy}); - if (sqdist < sqabsdist) { + double lat = docy / 1.0e6; + double lng = docx / 1.0e6; + double d = loc.km_great_circle_distance(lat, lng); + if (d < dist) { + dist = d; _best_index = i; - _best_x = docx; - _best_y = docy; - sqabsdist = sqdist; + _best_lat = lat; + _best_lng = lng; } } } - return static_cast<feature_t>(std::sqrt(static_cast<feature_t>(sqabsdist))); + return dist; } -DistanceExecutor::DistanceExecutor(GeoLocationSpecPtrs locations, - const search::attribute::IAttributeVector * pos) : - FeatureExecutor(), - _locations(locations), - _pos(pos), - _intBuf() +GeoGCDExecutor::GeoGCDExecutor(GeoLocationSpecPtrs locations, const attribute::IAttributeVector * pos) + : FeatureExecutor(), + _locations(), + _pos(pos), + _intBuf() { - if (_pos != nullptr) { - _intBuf.allocate(_pos->getMaxValueCount()); + if (_pos == nullptr) { + return; + } + _intBuf.allocate(_pos->getMaxValueCount()); + for (const auto * p : locations) { + if (p && p->location.valid() && p->location.has_point) { + double lat = p->location.point.y / 1.0e6; + double lng = p->location.point.x / 1.0e6; + _locations.emplace_back(search::common::GeoGcd{lat, lng}); + } } } + void -DistanceExecutor::execute(uint32_t docId) +GeoGCDExecutor::execute(uint32_t docId) { - static constexpr double earth_mean_radius = 6371.0088; - static constexpr double deg_to_rad = M_PI / 180.0; - static constexpr double km_from_internal = 1.0e-6 * deg_to_rad * earth_mean_radius; - feature_t internal_d = calculateDistance(docId); - outputs().set_number(0, internal_d); + double dist_km = calculateGeoGCD(docId); + double micro_degrees = search::common::GeoGcd::km_to_internal(dist_km); + if (_best_index < 0) { + dist_km = 40000.0; + micro_degrees = DistanceExecutor::DEFAULT_DISTANCE; + } + outputs().set_number(0, micro_degrees); outputs().set_number(1, _best_index); - outputs().set_number(2, _best_y * 1.0e-6); // latitude - outputs().set_number(3, _best_x * 1.0e-6); // longitude - outputs().set_number(4, internal_d * km_from_internal); // km + outputs().set_number(2, _best_lat); // latitude + outputs().set_number(3, _best_lng); // longitude + outputs().set_number(4, dist_km); } -const feature_t DistanceExecutor::DEFAULT_DISTANCE(6400000000.0); - - DistanceBlueprint::DistanceBlueprint() : Blueprint("distance"), _field_name(), - _arg_string(), + _label_name(), + _attr_name(), _attr_id(search::index::Schema::UNKNOWN_FIELD_ID), _use_geo_pos(false), _use_nns_tensor(false), @@ -166,7 +191,7 @@ DistanceBlueprint::createInstance() const bool DistanceBlueprint::setup_geopos(const vespalib::string &attr) { - _arg_string = attr; + _attr_name = attr; _use_geo_pos = true; describeOutput("out", "The euclidean distance from the query 
position."); describeOutput("index", "Index in array of closest point"); @@ -179,7 +204,7 @@ DistanceBlueprint::setup_geopos(const vespalib::string &attr) bool DistanceBlueprint::setup_nns(const vespalib::string &attr) { - _arg_string = attr; + _attr_name = attr; _use_nns_tensor = true; describeOutput("out", "The euclidean distance from the query position."); return true; @@ -195,7 +220,7 @@ DistanceBlueprint::setup(const IIndexEnvironment & env, // params[0] = field / label // params[1] = attribute name / label value if (arg == "label") { - _arg_string = params[1].getValue(); + _label_name = params[1].getValue(); _use_item_label = true; describeOutput("out", "The euclidean distance from the labeled query item."); return true; @@ -241,7 +266,7 @@ DistanceBlueprint::prepareSharedState(const fef::IQueryEnvironment& env, fef::IO DistanceCalculatorBundle::prepare_shared_state(env, store, _attr_id, "distance"); } if (_use_item_label) { - DistanceCalculatorBundle::prepare_shared_state(env, store, _arg_string, "distance"); + DistanceCalculatorBundle::prepare_shared_state(env, store, _label_name, "distance"); } } @@ -252,7 +277,7 @@ DistanceBlueprint::createExecutor(const IQueryEnvironment &env, vespalib::Stash return stash.create<ConvertRawscoreToDistance>(env, _attr_id); } if (_use_item_label) { - return stash.create<ConvertRawscoreToDistance>(env, _arg_string); + return stash.create<ConvertRawscoreToDistance>(env, _label_name); } // expect geo pos: const search::attribute::IAttributeVector * pos = nullptr; @@ -261,42 +286,41 @@ DistanceBlueprint::createExecutor(const IQueryEnvironment &env, vespalib::Stash for (auto loc_ptr : env.getAllLocations()) { if (_use_geo_pos && loc_ptr && loc_ptr->location.valid()) { - if (loc_ptr->field_name == _arg_string || + if (loc_ptr->field_name == _attr_name || loc_ptr->field_name == _field_name) { - LOG(debug, "found loc from query env matching '%s'", _arg_string.c_str()); + LOG(debug, "found loc from query env matching '%s'", _attr_name.c_str()); matching_locs.push_back(loc_ptr); } else { LOG(debug, "found loc(%s) from query env not matching arg(%s)", - loc_ptr->field_name.c_str(), _arg_string.c_str()); + loc_ptr->field_name.c_str(), _attr_name.c_str()); other_locs.push_back(loc_ptr); } } } if (matching_locs.empty() && other_locs.empty()) { LOG(debug, "createExecutor: no valid locations"); - return stash.create<DistanceExecutor>(matching_locs, nullptr); + return stash.create<GeoGCDExecutor>(matching_locs, nullptr); } - LOG(debug, "createExecutor: valid location, attribute='%s'", _arg_string.c_str()); - + LOG(debug, "createExecutor: valid location, attribute='%s'", _attr_name.c_str()); if (_use_geo_pos) { - pos = env.getAttributeContext().getAttribute(_arg_string); + pos = env.getAttributeContext().getAttribute(_attr_name); if (pos != nullptr) { if (!pos->isIntegerType()) { - Issue::report("distance feature: The position attribute '%s' is not an integer attribute. Will use default distance.", + Issue::report("distance feature: The position attribute '%s' is not an integer attribute.", pos->getName().c_str()); pos = nullptr; } else if (pos->getCollectionType() == attribute::CollectionType::WSET) { - Issue::report("distance feature: The position attribute '%s' is a weighted set attribute. Will use default distance.", + Issue::report("distance feature: The position attribute '%s' is a weighted set attribute.", pos->getName().c_str()); pos = nullptr; } } else { - Issue::report("distance feature: The position attribute '%s' was not found. 
Will use default distance.", _arg_string.c_str()); + Issue::report("distance feature: The position attribute '%s' was not found.", _attr_name.c_str()); } } LOG(debug, "use '%s' locations with pos=%p", matching_locs.empty() ? "other" : "matching", pos); - return stash.create<DistanceExecutor>(matching_locs.empty() ? other_locs : matching_locs, pos); + return stash.create<GeoGCDExecutor>(matching_locs.empty() ? other_locs : matching_locs, pos); } } diff --git a/searchlib/src/vespa/searchlib/features/distancefeature.h b/searchlib/src/vespa/searchlib/features/distancefeature.h index 6fc4665117a..7d5caad482d 100644 --- a/searchlib/src/vespa/searchlib/features/distancefeature.h +++ b/searchlib/src/vespa/searchlib/features/distancefeature.h @@ -13,29 +13,8 @@ using GeoLocationSpecPtrs = std::vector<const search::common::GeoLocationSpec *> /** * Implements the executor for the distance feature. */ -class DistanceExecutor : public fef::FeatureExecutor { -private: - GeoLocationSpecPtrs _locations; - const attribute::IAttributeVector * _pos; - attribute::IntegerContent _intBuf; - feature_t _best_index; - feature_t _best_x; - feature_t _best_y; - - feature_t calculateDistance(uint32_t docId); - feature_t calculate2DZDistance(uint32_t docId); - +class DistanceExecutor { public: - /** - * Constructs an executor for the distance feature. - * - * @param locations location objects associated with the query environment. - * @param pos the attribute to use for positions (expects zcurve encoding). - */ - DistanceExecutor(GeoLocationSpecPtrs locations, - const attribute::IAttributeVector * pos); - void execute(uint32_t docId) override; - static const feature_t DEFAULT_DISTANCE; }; @@ -45,7 +24,8 @@ public: class DistanceBlueprint : public fef::Blueprint { private: vespalib::string _field_name; - vespalib::string _arg_string; + vespalib::string _label_name; + vespalib::string _attr_name; uint32_t _attr_id; bool _use_geo_pos; bool _use_nns_tensor; diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 05a75f4662e..63d52cbdf9f 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -2,13 +2,19 @@ vespa_add_library(searchlib_query_streaming OBJECT SOURCES dot_product_term.cpp + fuzzy_term.cpp + hit_iterator_pack.cpp in_term.cpp multi_term.cpp + near_query_node.cpp nearest_neighbor_query_node.cpp + onear_query_node.cpp + phrase_query_node.cpp query.cpp querynode.cpp querynoderesultbase.cpp queryterm.cpp + same_element_query_node.cpp wand_term.cpp weighted_set_term.cpp regexp_term.cpp diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 1871bda564d..09840d9a126 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -24,7 +24,7 @@ DotProductTerm::build_scores(Scores& scores) const for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); + scores[hit.field_id()] += ((int64_t)term->weight().percent()) * hit.element_weight(); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp new file mode 100644 index 00000000000..f33fe44369a --- /dev/null 
+++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp @@ -0,0 +1,43 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "fuzzy_term.h" + +namespace search::streaming { + +namespace { + +constexpr bool normalizing_implies_cased(Normalizing norm) noexcept { + return (norm == Normalizing::NONE); +} + +} + +FuzzyTerm::FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size) + : QueryTerm(std::move(result_base), term, index, type, normalizing), + _dfa_matcher(), + _fallback_matcher() +{ + setFuzzyMaxEditDistance(max_edits); + setFuzzyPrefixLength(prefix_size); + + std::string_view term_view(term.data(), term.size()); + const bool cased = normalizing_implies_cased(normalizing); + if (attribute::DfaFuzzyMatcher::supports_max_edits(max_edits)) { + _dfa_matcher = std::make_unique<attribute::DfaFuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } else { + _fallback_matcher = std::make_unique<vespalib::FuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } +} + +FuzzyTerm::~FuzzyTerm() = default; + +bool FuzzyTerm::is_match(std::string_view term) const { + if (_dfa_matcher) { + return _dfa_matcher->is_match(term); + } else { + return _fallback_matcher->isMatch(term); + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h new file mode 100644 index 00000000000..c6c88b18969 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "queryterm.h" +#include <vespa/searchlib/attribute/dfa_fuzzy_matcher.h> +#include <vespa/vespalib/fuzzy/fuzzy_matcher.h> +#include <memory> +#include <string_view> + +namespace search::streaming { + +/** + * Query term that matches candidate field terms that are within a query-specified + * maximum number of edits (add, delete or substitute a character), with case + * sensitivity controlled by the provided Normalizing mode. + * + * Optionally, terms may be prefixed-locked, which enforces field terms to have a + * particular prefix and where edits are only counted for the remaining term suffix. 
+ */ +class FuzzyTerm : public QueryTerm { + std::unique_ptr<attribute::DfaFuzzyMatcher> _dfa_matcher; + std::unique_ptr<vespalib::FuzzyMatcher> _fallback_matcher; +public: + FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size); + ~FuzzyTerm() override; + + [[nodiscard]] FuzzyTerm* as_fuzzy_term() noexcept override { return this; } + + [[nodiscard]] bool is_match(std::string_view term) const; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit.h b/searchlib/src/vespa/searchlib/query/streaming/hit.h index a798d293491..168c09a91ec 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/hit.h +++ b/searchlib/src/vespa/searchlib/query/streaming/hit.h @@ -8,23 +8,25 @@ namespace search::streaming { class Hit { + uint32_t _field_id; + uint32_t _element_id; + int32_t _element_weight; + uint32_t _element_length; + uint32_t _position; public: - Hit(uint32_t pos_, uint32_t context_, uint32_t elemId_, int32_t weight_) noexcept - : _position(pos_ | (context_<<24)), - _elemId(elemId_), - _weight(weight_) + Hit(uint32_t field_id_, uint32_t element_id_, int32_t element_weight_, uint32_t position_) noexcept + : _field_id(field_id_), + _element_id(element_id_), + _element_weight(element_weight_), + _element_length(0), + _position(position_) { } - int32_t weight() const { return _weight; } - uint32_t pos() const { return _position; } - uint32_t wordpos() const { return _position & 0xffffff; } - uint32_t context() const { return _position >> 24; } - uint32_t elemId() const { return _elemId; } - bool operator < (const Hit & b) const { return cmp(b) < 0; } -private: - int cmp(const Hit & b) const { return _position - b._position; } - uint32_t _position; - uint32_t _elemId; - int32_t _weight; + uint32_t field_id() const noexcept { return _field_id; } + uint32_t element_id() const { return _element_id; } + int32_t element_weight() const { return _element_weight; } + uint32_t element_length() const { return _element_length; } + uint32_t position() const { return _position; } + void set_element_length(uint32_t value) { _element_length = value; } }; using HitList = std::vector<Hit>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h new file mode 100644 index 00000000000..21eba679abd --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h @@ -0,0 +1,73 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "hit.h" + +namespace search::streaming { + +/* + * Iterator used over hit list for a term to support near, onear, phrase and + * same element query nodes. 
+ */ +class HitIterator +{ + HitList::const_iterator _cur; + HitList::const_iterator _end; +public: + using FieldElement = std::pair<uint32_t, uint32_t>; + HitIterator(const HitList& hl) noexcept + : _cur(hl.begin()), + _end(hl.end()) + { } + bool valid() const noexcept { return _cur != _end; } + const Hit* operator->() const noexcept { return _cur.operator->(); } + const Hit& operator*() const noexcept { return _cur.operator*(); } + FieldElement get_field_element() const noexcept { return std::make_pair(_cur->field_id(), _cur->element_id()); } + bool seek_to_field_element(const FieldElement& field_element) noexcept { + while (valid()) { + if (!(get_field_element() < field_element)) { + return true; + } + ++_cur; + } + return false; + } + /* + * Step iterator forwards within the scope of the same field + * element. Return true if iterator is valid and with the same + * field element, otherwise return false but update field_element + * if iterator is valid to prepare for hit iterator pack seeking + * to next matching field element. + */ + bool step_in_field_element(FieldElement& field_element) noexcept { + ++_cur; + if (!valid()) { + return false; + } + auto ife = get_field_element(); + if (field_element < ife) { + field_element = ife; + return false; + } + return true; + } + /* + * Seek to position within the scope of the same field element. + * Return true if iterator is valid and with the same field + * element, otherwise return false but update field element if + * iterator is valid to prepare for hit iterator pack seeking to + * next matching field element. + */ + bool seek_in_field_element(uint32_t position, FieldElement& field_element) { + while (_cur->position() < position) { + if (!step_in_field_element(field_element)) { + return false; + } + } + return true; + } + HitIterator& operator++() { ++_cur; return *this; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp new file mode 100644 index 00000000000..fabd992c379 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp @@ -0,0 +1,57 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
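The new positional operators (NEAR, ONEAR, phrase, same element) all work on the (field_id, element_id) pairs exposed by HitIterator above; HitIteratorPack, whose implementation continues below, repeatedly advances every child iterator until they agree on one field element. A small stand-alone sketch of that alignment step for two iterators, using invented hit data and only the Hit/HitIterator API shown above; the include path is an assumption based on the file location in this commit:

    #include <vespa/searchlib/query/streaming/hit_iterator.h>
    #include <algorithm>
    #include <cstdio>

    using search::streaming::HitIterator;
    using search::streaming::HitList;

    int main() {
        // Hit(field_id, element_id, element_weight, position), matching the new hit.h layout; data is invented.
        HitList term_a{{0, 0, 100, 2}, {1, 3, 100, 4}};
        HitList term_b{{1, 3, 100, 6}};
        HitIterator it_a(term_a);
        HitIterator it_b(term_b);
        // Replays HitIteratorPack::seek_to_matching_field_element() for two iterators:
        // keep advancing both until they point at the same (field_id, element_id).
        auto fe = it_a.get_field_element();
        while (it_a.seek_to_field_element(fe) && it_b.seek_to_field_element(fe)) {
            auto fe_a = it_a.get_field_element();
            auto fe_b = it_b.get_field_element();
            if (fe_a == fe_b) {
                // A NEAR/ONEAR node would now compare positions 4 and 6 inside this element.
                std::printf("common element: field %u element %u, positions %u and %u\n",
                            unsigned(fe_a.first), unsigned(fe_a.second),
                            unsigned(it_a->position()), unsigned(it_b->position()));
                return 0;
            }
            fe = std::max(fe_a, fe_b);
        }
        std::printf("no common field element\n");
        return 0;
    }

Once a common field element is found, the ordered and unordered variants differ only in how they constrain the positions within it, which is what the evaluate_helper<ordered> template in near_query_node.cpp below expresses.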
+ +#include "hit_iterator_pack.h" + +namespace search::streaming { + + +HitIteratorPack::HitIteratorPack(const QueryNodeList& children) + : _iterators(), + _field_element(std::make_pair(0, 0)) +{ + auto num_children = children.size(); + _iterators.reserve(num_children); + for (auto& child : children) { + auto& curr = dynamic_cast<const QueryTerm&>(*child); + _iterators.emplace_back(curr.getHitList()); + } +} + +HitIteratorPack::~HitIteratorPack() = default; + +bool +HitIteratorPack::all_valid() const noexcept +{ + if (_iterators.empty()) { + return false; + } + for (auto& it : _iterators) { + if (!it.valid()) { + return false; + } + } + return true; +} + +bool +HitIteratorPack::seek_to_matching_field_element() noexcept +{ + bool retry = true; + while (retry) { + retry = false; + for (auto& it : _iterators) { + if (!it.seek_to_field_element(_field_element)) { + return false; + } + auto ife = it.get_field_element(); + if (_field_element < ife) { + _field_element = ife; + retry = true; + break; + } + } + } + return true; +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h new file mode 100644 index 00000000000..200d579930a --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h @@ -0,0 +1,32 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "hit_iterator.h" +#include "queryterm.h" + +namespace search::streaming { + +/* + * Iterator pack used over hit list for a term to support near, onear, + * phrase and same element query nodes. + */ +class HitIteratorPack +{ + using iterator = typename std::vector<HitIterator>::iterator; + using FieldElement = HitIterator::FieldElement; + std::vector<HitIterator> _iterators; + FieldElement _field_element; +public: + explicit HitIteratorPack(const QueryNodeList& children); + ~HitIteratorPack(); + FieldElement& get_field_element_ref() noexcept { return _field_element; } + HitIterator& front() noexcept { return _iterators.front(); } + HitIterator& back() noexcept { return _iterators.back(); } + iterator begin() noexcept { return _iterators.begin(); } + iterator end() noexcept { return _iterators.end(); } + bool all_valid() const noexcept; + bool seek_to_matching_field_element() noexcept; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp index 3e75f4a5114..c164db69ba1 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp @@ -29,9 +29,9 @@ InTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_ for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - if (!prev_field_id.has_value() || prev_field_id.value() != hit.context()) { - prev_field_id = hit.context(); - matching_field_ids.insert(hit.context()); + if (!prev_field_id.has_value() || prev_field_id.value() != hit.field_id()) { + prev_field_id = hit.field_id(); + matching_field_ids.insert(hit.field_id()); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp index dd34b9b7e73..f5a09892551 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp @@ -51,11 +51,4 @@ MultiTerm::evaluate() 
const return false; } -const HitList& -MultiTerm::evaluateHits(HitList& hl) const -{ - hl.clear(); - return hl; -} - } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h index 3bb69e29693..6f795c31356 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h @@ -32,7 +32,6 @@ public: MultiTerm* as_multi_term() noexcept override { return this; } void reset() override; bool evaluate() const override; - const HitList& evaluateHits(HitList& hl) const override; virtual void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) = 0; const std::vector<std::unique_ptr<QueryTerm>>& get_terms() const noexcept { return _terms; } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp new file mode 100644 index 00000000000..a777841bd70 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp @@ -0,0 +1,65 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "near_query_node.h" +#include "hit_iterator_pack.h" +#include <vespa/vespalib/objects/visit.hpp> + +namespace search::streaming { + +template <bool ordered> +bool +NearQueryNode::evaluate_helper() const +{ + HitIteratorPack itr_pack(getChildren()); + if (!itr_pack.all_valid()) { + return false; + } + while (itr_pack.seek_to_matching_field_element()) { + uint32_t min_position = 0; + if (itr_pack.front()->position() > min_position + distance()) { + min_position = itr_pack.front()->position() - distance(); + } + bool retry_element = true; + while (retry_element) { + bool match = true; + uint32_t min_next_position = min_position; + for (auto& it : itr_pack) { + if (!it.seek_in_field_element(min_next_position, itr_pack.get_field_element_ref())) { + retry_element = false; + match = false; + break; + } + if (it->position() > min_position + distance()) { + min_position = it->position() - distance(); + match = false; + break; + } + if constexpr (ordered) { + min_next_position = it->position() + 1; + } + } + if (match) { + return true; + } + } + } + return false; +} + +bool +NearQueryNode::evaluate() const +{ + return evaluate_helper<false>(); +} + +void +NearQueryNode::visitMembers(vespalib::ObjectVisitor &visitor) const +{ + AndQueryNode::visitMembers(visitor); + visit(visitor, "distance", static_cast<uint64_t>(_distance)); +} + +template bool NearQueryNode::evaluate_helper<false>() const; +template bool NearQueryNode::evaluate_helper<true>() const; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h new file mode 100644 index 00000000000..9258c3efe27 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h @@ -0,0 +1,29 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary Near operator. All terms must be within the given distance. 
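
NearQueryNode::evaluate_helper above slides a window over the children's hits within one field element. As a reference for the unordered case, a brute-force sketch assuming NEAR means every term must have an occurrence inside some window of width distance; near_match and term_positions are illustrative names, not searchlib code:

#include <algorithm>
#include <cstdint>
#include <vector>

// Does any window of width 'distance' contain at least one occurrence of
// every term? Each inner vector holds one term's positions, sorted ascending,
// within a single field element. Brute force for clarity.
bool near_match(const std::vector<std::vector<uint32_t>>& term_positions, uint32_t distance)
{
    for (const auto& starts : term_positions) {
        for (uint32_t s : starts) {        // candidate window start
            bool all_in_window = true;
            for (const auto& positions : term_positions) {
                auto it = std::lower_bound(positions.begin(), positions.end(), s);
                if (it == positions.end() || *it > s + distance) {
                    all_in_window = false;
                    break;
                }
            }
            if (all_in_window) {
                return true;
            }
        }
    }
    return false;
}
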
+*/ +class NearQueryNode : public AndQueryNode +{ +protected: + template <bool ordered> + bool evaluate_helper() const; +public: + NearQueryNode() noexcept : AndQueryNode("NEAR"), _distance(0) { } + explicit NearQueryNode(const char * opName) noexcept : AndQueryNode(opName), _distance(0) { } + bool evaluate() const override; + void distance(size_t dist) { _distance = dist; } + size_t distance() const { return _distance; } + void visitMembers(vespalib::ObjectVisitor &visitor) const override; + bool isFlattenable(ParseItem::ItemType) const override { return false; } +private: + size_t _distance; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp new file mode 100644 index 00000000000..45c04b411dc --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp @@ -0,0 +1,14 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onear_query_node.h" +#include "hit_iterator_pack.h" + +namespace search::streaming { + +bool +ONearQueryNode::evaluate() const +{ + return evaluate_helper<true>(); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h new file mode 100644 index 00000000000..649496b62d9 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h @@ -0,0 +1,20 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "near_query_node.h" + +namespace search::streaming { + +/** + N-ary Ordered near operator. The terms must be in order and the distance between + the first and last must not exceed the given distance. +*/ +class ONearQueryNode : public NearQueryNode +{ +public: + ONearQueryNode() noexcept : NearQueryNode("ONEAR") { } + bool evaluate() const override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp new file mode 100644 index 00000000000..2d2778417fa --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp @@ -0,0 +1,82 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "phrase_query_node.h" +#include "hit_iterator_pack.h" +#include <cassert> + +namespace search::streaming { + +bool +PhraseQueryNode::evaluate() const +{ + HitList hl; + return ! 
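
ONearQueryNode above reuses evaluate_helper<true>, which additionally forces the terms to appear in query order. A greedy reference sketch under the semantics stated in the header comment (terms in order, first-to-last span at most distance); onear_match and term_positions are illustrative names, not searchlib code:

#include <algorithm>
#include <cstdint>
#include <vector>

// Terms must appear in query order, and the span from the first to the last
// chosen position must not exceed 'distance'. Greedy: picking the earliest
// legal occurrence of each later term never hurts.
bool onear_match(const std::vector<std::vector<uint32_t>>& term_positions, uint32_t distance)
{
    for (uint32_t first : term_positions.front()) {
        uint32_t prev = first;
        bool ok = true;
        for (size_t i = 1; i < term_positions.size(); ++i) {
            const auto& positions = term_positions[i];
            auto it = std::upper_bound(positions.begin(), positions.end(), prev);
            if (it == positions.end() || *it > first + distance) {
                ok = false;
                break;
            }
            prev = *it;
        }
        if (ok) {
            return true;
        }
    }
    return false;
}
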
evaluateHits(hl).empty(); +} + +void PhraseQueryNode::getPhrases(QueryNodeRefList & tl) { tl.push_back(this); } +void PhraseQueryNode::getPhrases(ConstQueryNodeRefList & tl) const { tl.push_back(this); } + +void +PhraseQueryNode::addChild(QueryNode::UP child) { + assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); + AndQueryNode::addChild(std::move(child)); +} + +const HitList & +PhraseQueryNode::evaluateHits(HitList & hl) const +{ + hl.clear(); + _fieldInfo.clear(); + HitIteratorPack itr_pack(getChildren()); + if (!itr_pack.all_valid()) { + return hl; + } + auto& last_child = dynamic_cast<const QueryTerm&>(*(*this)[size() - 1]); + while (itr_pack.seek_to_matching_field_element()) { + uint32_t first_position = itr_pack.front()->position(); + bool retry_element = true; + while (retry_element) { + uint32_t position_offset = 0; + bool match = true; + for (auto& it : itr_pack) { + if (!it.seek_in_field_element(first_position + position_offset, itr_pack.get_field_element_ref())) { + retry_element = false; + match = false; + break; + } + if (it->position() > first_position + position_offset) { + first_position = it->position() - position_offset; + match = false; + break; + } + ++position_offset; + } + if (match) { + auto h = *itr_pack.back(); + hl.push_back(h); + auto& fi = last_child.getFieldInfo(h.field_id()); + updateFieldInfo(h.field_id(), hl.size() - 1, fi.getFieldLength()); + if (!itr_pack.front().step_in_field_element(itr_pack.get_field_element_ref())) { + retry_element = false; + } + } + } + } + return hl; +} + +void +PhraseQueryNode::updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const +{ + if (fid >= _fieldInfo.size()) { + _fieldInfo.resize(fid + 1); + // only set hit offset and field length the first time + QueryTerm::FieldInfo & fi = _fieldInfo[fid]; + fi.setHitOffset(offset); + fi.setFieldLength(fieldLength); + } + QueryTerm::FieldInfo & fi = _fieldInfo[fid]; + fi.setHitCount(fi.getHitCount() + 1); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h new file mode 100644 index 00000000000..b137f813150 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h @@ -0,0 +1,30 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary phrase operator. All terms must be satisfied and have the correct order + with distance to next term equal to 1. 
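
PhraseQueryNode::evaluateHits above requires term number i to occur exactly at first_position + i within the same field element. A small reference sketch of that adjacency check over sorted per-term position lists; phrase_match and term_positions are illustrative names, not searchlib code:

#include <algorithm>
#include <cstdint>
#include <vector>

// Term number 'offset' must occur exactly at start + offset; a phrase match
// exists if this holds for some occurrence of the first term. Each inner
// vector is one term's sorted position list within a single field element.
bool phrase_match(const std::vector<std::vector<uint32_t>>& term_positions)
{
    for (uint32_t start : term_positions.front()) {
        bool match = true;
        for (size_t offset = 0; offset < term_positions.size(); ++offset) {
            const auto& positions = term_positions[offset];
            if (!std::binary_search(positions.begin(), positions.end(),
                                    start + static_cast<uint32_t>(offset))) {
                match = false;
                break;
            }
        }
        if (match) {
            return true;
        }
    }
    return false;
}
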
+*/ +class PhraseQueryNode : public AndQueryNode +{ +public: + PhraseQueryNode() noexcept : AndQueryNode("PHRASE"), _fieldInfo(32) { } + bool evaluate() const override; + const HitList & evaluateHits(HitList & hl) const override; + void getPhrases(QueryNodeRefList & tl) override; + void getPhrases(ConstQueryNodeRefList & tl) const override; + const QueryTerm::FieldInfo & getFieldInfo(size_t fid) const { return _fieldInfo[fid]; } + size_t getFieldInfoSize() const { return _fieldInfo.size(); } + bool isFlattenable(ParseItem::ItemType) const override { return false; } + void addChild(QueryNode::UP child) override; +private: + mutable std::vector<QueryTerm::FieldInfo> _fieldInfo; + void updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp index ca742aabe26..5b0076a30c5 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp @@ -1,5 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "query.h" +#include "near_query_node.h" +#include "onear_query_node.h" +#include "phrase_query_node.h" +#include "same_element_query_node.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/vespalib/objects/visit.hpp> #include <cassert> @@ -31,7 +35,7 @@ const HitList & QueryConnector::evaluateHits(HitList & hl) const { if (evaluate()) { - hl.emplace_back(1, 0, 0, 1); + hl.emplace_back(0, 0, 1, 1); } return hl; } @@ -113,6 +117,7 @@ QueryConnector::create(ParseItem::ItemType type) case search::ParseItem::ITEM_SAME_ELEMENT: return std::make_unique<SameElementQueryNode>(); case search::ParseItem::ITEM_NEAR: return std::make_unique<NearQueryNode>(); case search::ParseItem::ITEM_ONEAR: return std::make_unique<ONearQueryNode>(); + case search::ParseItem::ITEM_RANK: return std::make_unique<RankWithQueryNode>(); default: return nullptr; } } @@ -158,174 +163,23 @@ OrQueryNode::evaluate() const { return false; } - -bool -EquivQueryNode::evaluate() const -{ - return OrQueryNode::evaluate(); -} - bool -SameElementQueryNode::evaluate() const { - HitList hl; - return ! 
evaluateHits(hl).empty(); -} - -void -SameElementQueryNode::addChild(QueryNode::UP child) { - assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); - AndQueryNode::addChild(std::move(child)); -} - -const HitList & -SameElementQueryNode::evaluateHits(HitList & hl) const -{ - hl.clear(); - if ( !AndQueryNode::evaluate()) return hl; - - HitList tmpHL; - const auto & children = getChildren(); - unsigned int numFields = children.size(); - unsigned int currMatchCount = 0; - std::vector<unsigned int> indexVector(numFields, 0); - auto curr = static_cast<const QueryTerm *> (children[currMatchCount].get()); - bool exhausted( curr->evaluateHits(tmpHL).empty()); - for (; !exhausted; ) { - auto next = static_cast<const QueryTerm *>(children[currMatchCount+1].get()); - unsigned int & currIndex = indexVector[currMatchCount]; - unsigned int & nextIndex = indexVector[currMatchCount+1]; - - const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - uint32_t currElemId = currHit.elemId(); - - const HitList & nextHL = next->evaluateHits(tmpHL); - - size_t nextIndexMax = nextHL.size(); - while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() < currElemId)) { - nextIndex++; - } - if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() == currElemId)) { - currMatchCount++; - if ((currMatchCount+1) == numFields) { - Hit h = nextHL[indexVector[currMatchCount]]; - hl.emplace_back(0, h.context(), h.elemId(), h.weight()); - currMatchCount = 0; - indexVector[0]++; - } - } else { - currMatchCount = 0; - indexVector[currMatchCount]++; - } - curr = static_cast<const QueryTerm *>(children[currMatchCount].get()); - exhausted = (nextIndex >= nextIndexMax) || (indexVector[currMatchCount] >= curr->evaluateHits(tmpHL).size()); - } - return hl; -} - -bool -PhraseQueryNode::evaluate() const -{ - HitList hl; - return ! evaluateHits(hl).empty(); -} - -void PhraseQueryNode::getPhrases(QueryNodeRefList & tl) { tl.push_back(this); } -void PhraseQueryNode::getPhrases(ConstQueryNodeRefList & tl) const { tl.push_back(this); } - -void -PhraseQueryNode::addChild(QueryNode::UP child) { - assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); - AndQueryNode::addChild(std::move(child)); -} - -const HitList & -PhraseQueryNode::evaluateHits(HitList & hl) const -{ - hl.clear(); - _fieldInfo.clear(); - if ( ! 
AndQueryNode::evaluate()) return hl; - - HitList tmpHL; - const auto & children = getChildren(); - unsigned int fullPhraseLen = children.size(); - unsigned int currPhraseLen = 0; - std::vector<unsigned int> indexVector(fullPhraseLen, 0); - auto curr = static_cast<const QueryTerm *> (children[currPhraseLen].get()); - bool exhausted( curr->evaluateHits(tmpHL).empty()); - for (; !exhausted; ) { - auto next = static_cast<const QueryTerm *>(children[currPhraseLen+1].get()); - unsigned int & currIndex = indexVector[currPhraseLen]; - unsigned int & nextIndex = indexVector[currPhraseLen+1]; - - const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - size_t firstPosition = currHit.pos(); - uint32_t currElemId = currHit.elemId(); - uint32_t currContext = currHit.context(); - - const HitList & nextHL = next->evaluateHits(tmpHL); - - int diff(0); - size_t nextIndexMax = nextHL.size(); - while ((nextIndex < nextIndexMax) && - ((nextHL[nextIndex].context() < currContext) || - ((nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() <= currElemId))) && - ((diff = nextHL[nextIndex].pos()-firstPosition) < 1)) - { - nextIndex++; - } - if ((diff == 1) && (nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() == currElemId)) { - currPhraseLen++; - if ((currPhraseLen+1) == fullPhraseLen) { - Hit h = nextHL[indexVector[currPhraseLen]]; - hl.push_back(h); - const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.context()); - updateFieldInfo(h.context(), hl.size() - 1, fi.getFieldLength()); - currPhraseLen = 0; - indexVector[0]++; - } - } else { - currPhraseLen = 0; - indexVector[currPhraseLen]++; +RankWithQueryNode::evaluate() const { + bool first = true; + bool firstOk = false; + for (const auto & qn : getChildren()) { + if (qn->evaluate()) { + if (first) firstOk = true; } - curr = static_cast<const QueryTerm *>(children[currPhraseLen].get()); - exhausted = (nextIndex >= nextIndexMax) || (indexVector[currPhraseLen] >= curr->evaluateHits(tmpHL).size()); - } - return hl; -} - -void -PhraseQueryNode::updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const -{ - if (fid >= _fieldInfo.size()) { - _fieldInfo.resize(fid + 1); - // only set hit offset and field length the first time - QueryTerm::FieldInfo & fi = _fieldInfo[fid]; - fi.setHitOffset(offset); - fi.setFieldLength(fieldLength); + first = false; } - QueryTerm::FieldInfo & fi = _fieldInfo[fid]; - fi.setHitCount(fi.getHitCount() + 1); -} - -bool -NearQueryNode::evaluate() const -{ - return AndQueryNode::evaluate(); -} - -void -NearQueryNode::visitMembers(vespalib::ObjectVisitor &visitor) const -{ - AndQueryNode::visitMembers(visitor); - visit(visitor, "distance", static_cast<uint64_t>(_distance)); + return firstOk; } - bool -ONearQueryNode::evaluate() const +EquivQueryNode::evaluate() const { - bool ok(NearQueryNode::evaluate()); - return ok; + return OrQueryNode::evaluate(); } Query::Query() = default; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 84c693b86d0..5296d3a4f69 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -76,7 +76,7 @@ class AndNotQueryNode : public QueryConnector public: AndNotQueryNode() noexcept : QueryConnector("ANDNOT") { } bool evaluate() const override; - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } + bool isFlattenable(ParseItem::ItemType) const override { return 
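
RankWithQueryNode::evaluate above gates matching on the first child only, while the remaining children are still evaluated so their hits stay available for ranking. A minimal sketch of that rule; rank_with_evaluate and the std::function children are illustrative stand-ins, not searchlib code:

#include <functional>
#include <vector>

// Only the first child decides whether the node matches; the other children
// are evaluated purely for their side effects (hits used by rank features).
bool rank_with_evaluate(const std::vector<std::function<bool()>>& children)
{
    bool first = true;
    bool first_ok = false;
    for (const auto& child : children) {
        bool matched = child();
        if (first) {
            first_ok = matched;
            first = false;
        }
    }
    return first_ok;
}
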
false; } }; /** @@ -95,79 +95,28 @@ public: }; /** - N-ary "EQUIV" operator that merges terms from nodes below. -*/ -class EquivQueryNode : public OrQueryNode -{ -public: - EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } - bool evaluate() const override; - bool isFlattenable(ParseItem::ItemType type) const override { - return (type == ParseItem::ITEM_EQUIV); - } -}; - -/** - N-ary phrase operator. All terms must be satisfied and have the correct order - with distance to next term equal to 1. + N-ary RankWith operator */ -class PhraseQueryNode : public AndQueryNode -{ -public: - PhraseQueryNode() noexcept : AndQueryNode("PHRASE"), _fieldInfo(32) { } - bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; - void getPhrases(QueryNodeRefList & tl) override; - void getPhrases(ConstQueryNodeRefList & tl) const override; - const QueryTerm::FieldInfo & getFieldInfo(size_t fid) const { return _fieldInfo[fid]; } - size_t getFieldInfoSize() const { return _fieldInfo.size(); } - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } - void addChild(QueryNode::UP child) override; -private: - mutable std::vector<QueryTerm::FieldInfo> _fieldInfo; - void updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const; -#if WE_EVER_NEED_TO_CACHE_THIS_WE_MIGHT_WANT_SOME_CODE_HERE - HitList _cachedHitList; - bool _evaluated; -#endif -}; - -class SameElementQueryNode : public AndQueryNode +class RankWithQueryNode : public QueryConnector { public: - SameElementQueryNode() noexcept : AndQueryNode("SAME_ELEMENT") { } + RankWithQueryNode() noexcept : QueryConnector("RANK") { } + explicit RankWithQueryNode(const char * opName) noexcept : QueryConnector(opName) { } bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } - void addChild(QueryNode::UP child) override; }; -/** - N-ary Near operator. All terms must be within the given distance. -*/ -class NearQueryNode : public AndQueryNode -{ -public: - NearQueryNode() noexcept : AndQueryNode("NEAR"), _distance(0) { } - explicit NearQueryNode(const char * opName) noexcept : AndQueryNode(opName), _distance(0) { } - bool evaluate() const override; - void distance(size_t dist) { _distance = dist; } - size_t distance() const { return _distance; } - void visitMembers(vespalib::ObjectVisitor &visitor) const override; - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } -private: - size_t _distance; -}; /** - N-ary Ordered near operator. The terms must be in order and the distance between - the first and last must not exceed the given distance. + N-ary "EQUIV" operator that merges terms from nodes below. */ -class ONearQueryNode : public NearQueryNode +class EquivQueryNode : public OrQueryNode { public: - ONearQueryNode() noexcept : NearQueryNode("ONEAR") { } + EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } bool evaluate() const override; + bool isFlattenable(ParseItem::ItemType type) const override { + return (type == ParseItem::ITEM_EQUIV); + } }; /** diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 2ee515f062a..32e3ec16b16 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -1,8 +1,12 @@ // Copyright Vespa.ai. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "query.h" +#include "fuzzy_term.h" +#include "near_query_node.h" #include "nearest_neighbor_query_node.h" +#include "phrase_query_node.h" +#include "query.h" #include "regexp_term.h" +#include "same_element_query_node.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> @@ -47,6 +51,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_SAME_ELEMENT: case ParseItem::ITEM_NEAR: case ParseItem::ITEM_ONEAR: + case ParseItem::ITEM_RANK: { qn = QueryConnector::create(type); if (qn) { @@ -147,17 +152,16 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } else { Normalizing normalize_mode = factory.normalizing_mode(ssIndex); std::unique_ptr<QueryTerm> qt; - if (sTerm != TermType::REGEXP) { - qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); - } else { + if (sTerm == TermType::REGEXP) { qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode); + } else if (sTerm == TermType::FUZZYTERM) { + qt = std::make_unique<FuzzyTerm>(factory.create(), ssTerm, ssIndex, TermType::FUZZYTERM, normalize_mode, + queryRep.getFuzzyMaxEditDistance(), queryRep.getFuzzyPrefixLength()); + } else [[likely]] { + qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); } qt->setWeight(queryRep.GetWeight()); qt->setUniqueId(queryRep.getUniqueId()); - if (qt->isFuzzy()) { - qt->setFuzzyMaxEditDistance(queryRep.getFuzzyMaxEditDistance()); - qt->setFuzzyPrefixLength(queryRep.getFuzzyPrefixLength()); - } if (allowRewrite && possibleFloat(*qt, ssTerm) && factory.allow_float_terms_rewrite(ssIndex)) { auto phrase = std::make_unique<PhraseQueryNode>(); auto dotPos = ssTerm.find('.'); @@ -173,17 +177,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } } break; - case ParseItem::ITEM_RANK: - { - if (arity >= 1) { - queryRep.next(); - qn = Build(parent, factory, queryRep, false); - for (uint32_t skipCount = arity-1; (skipCount > 0) && queryRep.next(); skipCount--) { - skipCount += queryRep.getArity(); - } - } - } - break; case ParseItem::ITEM_STRING_IN: qn = std::make_unique<InTerm>(factory.create(), queryRep.getIndexName(), queryRep.get_terms(), factory.normalizing_mode(queryRep.getIndexName())); diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp index d72a3371846..af8ce7c9994 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "querynoderesultbase.h" +#include <ostream> namespace search::streaming { diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 3e05d381ee2..b7e619cfe4c 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -162,9 +162,18 @@ void QueryTerm::resizeFieldId(size_t fieldNo) } } -void QueryTerm::add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight_) +uint32_t +QueryTerm::add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position) { - _hitList.emplace_back(pos, context, elemId, weight_); + uint32_t idx = _hitList.size(); + _hitList.emplace_back(field_id, element_id, element_weight, position); + return idx; +} + +void +QueryTerm::set_element_length(uint32_t hitlist_idx, uint32_t element_length) +{ + _hitList[hitlist_idx].set_element_length(element_length); } NearestNeighborQueryNode* @@ -185,4 +194,10 @@ QueryTerm::as_regexp_term() noexcept return nullptr; } +FuzzyTerm* +QueryTerm::as_fuzzy_term() noexcept +{ + return nullptr; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index cd2bdd7eaec..504b94de747 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -11,6 +11,7 @@ namespace search::streaming { +class FuzzyTerm; class NearestNeighborQueryNode; class MultiTerm; class RegexpTerm; @@ -64,7 +65,7 @@ public: QueryTerm & operator = (QueryTerm &&) = delete; ~QueryTerm() override; bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; + const HitList & evaluateHits(HitList & hl) const final override; void reset() override; void getLeaves(QueryTermList & tl) override; void getLeaves(ConstQueryTermList & tl) const override; @@ -73,7 +74,8 @@ public: /// Gives you all phrases of this tree. Indicating that they are all const. void getPhrases(ConstQueryNodeRefList & tl) const override; - void add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight); + uint32_t add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position); + void set_element_length(uint32_t hitlist_idx, uint32_t element_length); EncodingBitMap encoding() const { return _encoding; } size_t termLen() const { return getTermLen(); } const string & index() const { return _index; } @@ -95,6 +97,7 @@ public: virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept; virtual MultiTerm* as_multi_term() noexcept; virtual RegexpTerm* as_regexp_term() noexcept; + virtual FuzzyTerm* as_fuzzy_term() noexcept; protected: using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>; string _index; diff --git a/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp new file mode 100644 index 00000000000..49d5fb0f9fb --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp @@ -0,0 +1,65 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "same_element_query_node.h" +#include <cassert> + +namespace search::streaming { + +bool +SameElementQueryNode::evaluate() const { + HitList hl; + return ! 
evaluateHits(hl).empty(); +} + +void +SameElementQueryNode::addChild(QueryNode::UP child) { + assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); + AndQueryNode::addChild(std::move(child)); +} + +const HitList & +SameElementQueryNode::evaluateHits(HitList & hl) const +{ + hl.clear(); + if ( !AndQueryNode::evaluate()) return hl; + + HitList tmpHL; + const auto & children = getChildren(); + unsigned int numFields = children.size(); + unsigned int currMatchCount = 0; + std::vector<unsigned int> indexVector(numFields, 0); + auto curr = static_cast<const QueryTerm *> (children[currMatchCount].get()); + bool exhausted( curr->evaluateHits(tmpHL).empty()); + for (; !exhausted; ) { + auto next = static_cast<const QueryTerm *>(children[currMatchCount+1].get()); + unsigned int & currIndex = indexVector[currMatchCount]; + unsigned int & nextIndex = indexVector[currMatchCount+1]; + + const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; + uint32_t currElemId = currHit.element_id(); + + const HitList & nextHL = next->evaluateHits(tmpHL); + + size_t nextIndexMax = nextHL.size(); + while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() < currElemId)) { + nextIndex++; + } + if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() == currElemId)) { + currMatchCount++; + if ((currMatchCount+1) == numFields) { + Hit h = nextHL[indexVector[currMatchCount]]; + hl.emplace_back(h.field_id(), h.element_id(), h.element_weight(), 0); + currMatchCount = 0; + indexVector[0]++; + } + } else { + currMatchCount = 0; + indexVector[currMatchCount]++; + } + curr = static_cast<const QueryTerm *>(children[currMatchCount].get()); + exhausted = (nextIndex >= nextIndexMax) || (indexVector[currMatchCount] >= curr->evaluateHits(tmpHL).size()); + } + return hl; +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h new file mode 100644 index 00000000000..8e675feb569 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h @@ -0,0 +1,22 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary Same element operator. All terms must be within the same element. 
+*/ +class SameElementQueryNode : public AndQueryNode +{ +public: + SameElementQueryNode() noexcept : AndQueryNode("SAME_ELEMENT") { } + bool evaluate() const override; + const HitList & evaluateHits(HitList & hl) const override; + bool isFlattenable(ParseItem::ItemType) const override { return false; } + void addChild(QueryNode::UP child) override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp index 90d0be5d43c..d2d706eef3d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -25,7 +25,7 @@ WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchDat for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()].emplace_back(term->weight().percent()); + scores[hit.field_id()].emplace_back(term->weight().percent()); } } auto num_fields = td.numFields(); diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h index 4473e0fa45b..3d8a5fba843 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -10,7 +10,6 @@ namespace search::streaming { * A weighted set query term for streaming search. */ class WeightedSetTerm : public MultiTerm { - double _score_threshold; public: WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); ~WeightedSetTerm() override; diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index 6ca072d6dc7..2b9b45c990e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -89,7 +89,6 @@ Blueprint::sat_sum(const std::vector<HitEstimate> &data, uint32_t docid_limit) Blueprint::State::State() noexcept : _fields(), - _relative_estimate(0.0), _estimateHits(0), _tree_size(1), _estimateEmpty(true), @@ -106,7 +105,6 @@ Blueprint::State::State(FieldSpecBase field) noexcept Blueprint::State::State(FieldSpecBaseList fields_in) noexcept : _fields(std::move(fields_in)), - _relative_estimate(0.0), _estimateHits(0), _tree_size(1), _estimateEmpty(true), @@ -120,7 +118,9 @@ Blueprint::State::~State() = default; Blueprint::Blueprint() noexcept : _parent(nullptr), - _cost(1.0), + _relative_estimate(0.0), + _cost(0.0), + _strict_cost(0.0), _sourceId(0xffffffff), _docid_limit(0), _frozen(false) @@ -130,15 +130,15 @@ Blueprint::Blueprint() noexcept Blueprint::~Blueprint() = default; Blueprint::UP -Blueprint::optimize(Blueprint::UP bp, bool sort_by_cost) { +Blueprint::optimize(Blueprint::UP bp) { Blueprint *root = bp.release(); - root->optimize(root, OptimizePass::FIRST, sort_by_cost); - root->optimize(root, OptimizePass::LAST, sort_by_cost); + root->optimize(root, OptimizePass::FIRST); + root->optimize(root, OptimizePass::LAST); return Blueprint::UP(root); } void -Blueprint::optimize_self(OptimizePass, bool) +Blueprint::optimize_self(OptimizePass) { } @@ -353,12 +353,13 @@ Blueprint::visitMembers(vespalib::ObjectVisitor &visitor) const visitor.openStruct("estimate", "HitEstimate"); visitor.visitBool("empty", state.estimate().empty); visitor.visitInt("estHits", state.estimate().estHits); - 
visitor.visitFloat("relative_estimate", state.relative_estimate()); visitor.visitInt("cost_tier", state.cost_tier()); visitor.visitInt("tree_size", state.tree_size()); visitor.visitBool("allow_termwise_eval", state.allow_termwise_eval()); visitor.closeStruct(); + visitor.visitFloat("relative_estimate", _relative_estimate); visitor.visitFloat("cost", _cost); + visitor.visitFloat("strict_cost", _strict_cost); visitor.visitInt("sourceId", _sourceId); visitor.visitInt("docid_limit", _docid_limit); } @@ -518,7 +519,6 @@ IntermediateBlueprint::calculateState() const { State state(exposeFields()); state.estimate(calculateEstimate()); - state.relative_estimate(calculate_relative_estimate()); state.cost_tier(calculate_cost_tier()); state.allow_termwise_eval(infer_allow_termwise_eval()); state.want_global_filter(infer_want_global_filter()); @@ -548,25 +548,33 @@ IntermediateBlueprint::should_do_termwise_eval(const UnpackInfo &unpack, double } void -IntermediateBlueprint::optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) +IntermediateBlueprint::optimize(Blueprint* &self, OptimizePass pass) { assert(self == this); - if (should_optimize_children()) { - for (auto &child : _children) { - auto *child_ptr = child.release(); - child_ptr->optimize(child_ptr, pass, sort_by_cost); - child.reset(child_ptr); - } + for (auto &child : _children) { + auto *child_ptr = child.release(); + child_ptr->optimize(child_ptr, pass); + child.reset(child_ptr); } - optimize_self(pass, sort_by_cost); + optimize_self(pass); if (pass == OptimizePass::LAST) { - sort(_children, sort_by_cost); + set_relative_estimate(calculate_relative_estimate()); set_cost(calculate_cost()); + set_strict_cost(calculate_strict_cost()); } maybe_eliminate_self(self, get_replacement()); } void +IntermediateBlueprint::sort(bool strict, bool sort_by_cost) +{ + sort(_children, strict, sort_by_cost); + for (size_t i = 0; i < _children.size(); ++i) { + _children[i]->sort(strict && inheritStrict(i), sort_by_cost); + } +} + +void IntermediateBlueprint::set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) { for (auto & child : _children) { @@ -710,23 +718,32 @@ IntermediateBlueprint::calculateUnpackInfo(const fef::MatchData & md) const //----------------------------------------------------------------------------- -void -LeafBlueprint::setDocIdLimit(uint32_t limit) noexcept { - Blueprint::setDocIdLimit(limit); - _state.relative_estimate(calculate_relative_estimate()); - notifyChange(); -} - double LeafBlueprint::calculate_relative_estimate() const { double rel_est = abs_to_rel_est(_state.estimate().estHits, get_docid_limit()); if (rel_est > 0.9) { // Assume we do not really know how much we are matching when - // we claim to match 'everything' + // we claim to match 'everything'. Also assume we are not able + // to skip documents efficiently when strict. + _can_skip = false; return 0.5; + } else { + _can_skip = true; + return rel_est; } - return rel_est; +} + +double +LeafBlueprint::calculate_cost() const +{ + return 1.0; +} + +double +LeafBlueprint::calculate_strict_cost() const +{ + return _can_skip ? 
estimate() * cost() : cost(); } void @@ -758,14 +775,24 @@ LeafBlueprint::getRange(vespalib::string &, vespalib::string &) const { } void -LeafBlueprint::optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) +LeafBlueprint::optimize(Blueprint* &self, OptimizePass pass) { assert(self == this); - optimize_self(pass, sort_by_cost); + optimize_self(pass); + if (pass == OptimizePass::LAST) { + set_relative_estimate(calculate_relative_estimate()); + set_cost(calculate_cost()); + set_strict_cost(calculate_strict_cost()); + } maybe_eliminate_self(self, get_replacement()); } void +LeafBlueprint::sort(bool, bool) +{ +} + +void LeafBlueprint::set_cost_tier(uint32_t value) { assert(value < 0x100); diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index d998c2e343e..e080d667dfa 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -53,7 +53,7 @@ public: using SearchIteratorUP = std::unique_ptr<SearchIterator>; enum class OptimizePass { FIRST, LAST }; - + struct HitEstimate { uint32_t estHits; bool empty; @@ -75,7 +75,6 @@ public: { private: FieldSpecBaseList _fields; - double _relative_estimate; uint32_t _estimateHits; uint32_t _tree_size : 20; bool _estimateEmpty : 1; @@ -111,9 +110,6 @@ public: return nullptr; } - void relative_estimate(double value) noexcept { _relative_estimate = value; } - double relative_estimate() const noexcept { return _relative_estimate; } - void estimate(HitEstimate est) noexcept { _estimateHits = est.estHits; _estimateEmpty = est.empty; @@ -182,7 +178,9 @@ public: private: Blueprint *_parent; + double _relative_estimate; double _cost; + double _strict_cost; uint32_t _sourceId; uint32_t _docid_limit; bool _frozen; @@ -198,7 +196,9 @@ protected: _frozen = true; } + void set_relative_estimate(double value) noexcept { _relative_estimate = value; } void set_cost(double value) noexcept { _cost = value; } + void set_strict_cost(double value) noexcept { _strict_cost = value; } public: class IPredicate { @@ -222,19 +222,22 @@ public: Blueprint *getParent() const noexcept { return _parent; } bool has_parent() const { return (_parent != nullptr); } - double cost() const noexcept { return _cost; } - Blueprint &setSourceId(uint32_t sourceId) noexcept { _sourceId = sourceId; return *this; } uint32_t getSourceId() const noexcept { return _sourceId; } virtual void setDocIdLimit(uint32_t limit) noexcept { _docid_limit = limit; } uint32_t get_docid_limit() const noexcept { return _docid_limit; } - static Blueprint::UP optimize(Blueprint::UP bp, bool sort_by_cost); - virtual void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) = 0; - virtual void optimize_self(OptimizePass pass, bool sort_by_cost); + static Blueprint::UP optimize(Blueprint::UP bp); + virtual void sort(bool strict, bool sort_by_cost) = 0; + static Blueprint::UP optimize_and_sort(Blueprint::UP bp, bool strict, bool sort_by_cost) { + auto result = optimize(std::move(bp)); + result->sort(strict, sort_by_cost); + return result; + } + virtual void optimize(Blueprint* &self, OptimizePass pass) = 0; + virtual void optimize_self(OptimizePass pass); virtual Blueprint::UP get_replacement(); - virtual bool should_optimize_children() const { return true; } virtual bool supports_termwise_children() const { return false; } virtual bool always_needs_unpack() const { return false; } @@ -254,9 +257,27 @@ public: const Blueprint &root() const; double hit_ratio() const { return 
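
Blueprint::optimize above no longer takes a sorting flag; sorting is a separate pass, with optimize_and_sort chaining the two. A minimal usage sketch based on the signatures in this diff; the prepare function, its arguments and the include path are hypothetical caller-side assumptions:

#include <vespa/searchlib/queryeval/blueprint.h>
#include <utility>

using search::queryeval::Blueprint;

Blueprint::UP prepare(Blueprint::UP root, bool strict, bool sort_by_cost) {
    // Equivalent to Blueprint::optimize_and_sort(std::move(root), strict, sort_by_cost):
    root = Blueprint::optimize(std::move(root)); // FIRST + LAST pass; computes estimate/cost/strict_cost
    root->sort(strict, sort_by_cost);            // orders children, propagating strictness downwards
    return root;
}
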
getState().hit_ratio(_docid_limit); } - double estimate() const { return getState().relative_estimate(); } - virtual double calculate_relative_estimate() const = 0; + // The flow statistics for a blueprint is calculated during the + // LAST optimize pass (just prior to sorting). The relative + // estimate may be used to calculate the costs and the non-strict + // cost may be used to calculate the strict cost. After being + // calculated, each value is available through a simple accessor + // function. Note that these values may not be available for + // blueprints used inside complex leafs (this case will probably + // be solved using custom flow adapters that has knowledge of + // docid limit). + // + // 'estimate': relative estimate in the range [0,1] + // 'cost': per-document cost of non-strict evaluation + // 'strict_cost': per-document cost of strict evaluation + double estimate() const noexcept { return _relative_estimate; } + double cost() const noexcept { return _cost; } + double strict_cost() const noexcept { return _strict_cost; } + virtual double calculate_relative_estimate() const = 0; + virtual double calculate_cost() const = 0; + virtual double calculate_strict_cost() const = 0; + virtual void fetchPostings(const ExecuteInfo &execInfo) = 0; virtual void freeze() = 0; bool frozen() const { return _frozen; } @@ -360,7 +381,8 @@ public: void setDocIdLimit(uint32_t limit) noexcept final; - void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) final; + void optimize(Blueprint* &self, OptimizePass pass) final; + void sort(bool strict, bool sort_by_cost) override; void set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) override; IndexList find(const IPredicate & check) const; @@ -374,10 +396,9 @@ public: Blueprint::UP removeLastChild() { return removeChild(childCnt() - 1); } SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; - virtual double calculate_cost() const = 0; virtual HitEstimate combine(const std::vector<HitEstimate> &data) const = 0; virtual FieldSpecBaseList exposeFields() const = 0; - virtual void sort(Children &children, bool sort_by_cost) const = 0; + virtual void sort(Children &children, bool strict, bool sort_by_cost) const = 0; virtual bool inheritStrict(size_t i) const = 0; virtual SearchIteratorUP createIntermediateSearch(MultiSearch::Children subSearches, @@ -396,11 +417,12 @@ class LeafBlueprint : public Blueprint { private: State _state; + mutable bool _can_skip = true; protected: - void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) final; + void optimize(Blueprint* &self, OptimizePass pass) final; + void sort(bool strict, bool sort_by_cost) override; void setEstimate(HitEstimate est) { _state.estimate(est); - _state.relative_estimate(calculate_relative_estimate()); notifyChange(); } void set_cost_tier(uint32_t value); @@ -431,9 +453,9 @@ protected: public: ~LeafBlueprint() override = default; const State &getState() const final { return _state; } - void setDocIdLimit(uint32_t limit) noexcept final; - using Blueprint::set_cost; double calculate_relative_estimate() const override; + double calculate_cost() const override; + double calculate_strict_cost() const override; void fetchPostings(const ExecuteInfo &execInfo) override; void freeze() final; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h index 
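
The comment block above defines the three flow statistics computed during the LAST optimize pass. For a leaf, this diff sets the non-strict cost to a constant 1.0 and the strict cost to estimate() * cost() when the iterator can skip, falling back to cost() otherwise. A tiny standalone sketch of that leaf model; LeafFlow is an illustrative stand-in, not a searchlib class:

// Simplified leaf flow model, following the leaf cost functions in this diff.
struct LeafFlow {
    double estimate;   // relative estimate in [0,1]
    bool   can_skip;   // able to skip documents efficiently when strict
    double cost() const { return 1.0; }
    double strict_cost() const { return can_skip ? estimate * cost() : cost(); }
};

// Example: a term matching ~1% of the documents gets strict_cost 0.01, while
// a term claiming to match "everything" (can_skip == false) keeps cost 1.0.
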
e49fcbcb5bc..c74c2d4e9a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h @@ -27,6 +27,7 @@ protected: public: static constexpr bool filter_search = false; static constexpr bool require_btree_iterators = true; + static constexpr bool supports_hash_filter = false; // TODO: use MultiSearch::Children to pass ownership static SearchIterator::UP create(const std::vector<SearchIterator*> &children, diff --git a/searchlib/src/vespa/searchlib/queryeval/equivsearch.h b/searchlib/src/vespa/searchlib/queryeval/equivsearch.h index f14f6d56308..343e27ceccc 100644 --- a/searchlib/src/vespa/searchlib/queryeval/equivsearch.h +++ b/searchlib/src/vespa/searchlib/queryeval/equivsearch.h @@ -13,7 +13,7 @@ namespace search::queryeval { /** * A simple implementation of the Equiv search operation. **/ -class EquivSearch : public SearchIterator +class EquivSearch { public: using Children = MultiSearch::Children; diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h index 86ce6f8b93b..b90321581b5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow.h @@ -13,19 +13,11 @@ namespace search::queryeval { namespace flow { // the default adapter expects the shape of std::unique_ptr<Blueprint> -// with respect to estimate, cost and (coming soon) strict_cost. +// with respect to estimate, cost and strict_cost. struct DefaultAdapter { double estimate(const auto &child) const noexcept { return child->estimate(); } double cost(const auto &child) const noexcept { return child->cost(); } - // Estimate the per-document cost of strict evaluation of this - // child. This will typically be something like (estimate() * - // cost()) for leafs with posting lists. OR will aggregate strict - // cost by calculating the minimal OR flow of strict child - // costs. AND will aggregate strict cost by calculating the - // minimal AND flow where the cost of the first child is - // substituted by its strict cost. This value is currently not - // available in Blueprints. 
- double strict_cost(const auto &child) const noexcept { return child->cost(); } + double strict_cost(const auto &child) const noexcept { return child->strict_cost(); } }; template <typename ADAPTER, typename T> @@ -154,8 +146,6 @@ struct FlowMixin { static double cost_of(const auto &children, bool strict) { return cost_of(flow::DefaultAdapter(), children, strict); } - // TODO: remove - static double cost_of(const auto &children) { return cost_of(children, false); } }; class AndFlow : public FlowMixin<AndFlow> { @@ -190,9 +180,8 @@ public: children[0] = std::move(the_one); } } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; @@ -224,9 +213,8 @@ public: flow::sort<flow::MinOrCost>(adapter, children); } } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; @@ -254,9 +242,8 @@ public: static void sort(auto adapter, auto &children, bool) { flow::sort_partial<flow::MinOrCost>(adapter, children, 1); } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp index e60fe3d3f85..8cabe189b0e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp @@ -33,7 +33,7 @@ size_t lookup_create_source(std::vector<std::unique_ptr<CombineType> > &sources, } template <typename CombineType> -void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx, bool sort_by_cost) { +void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx) { std::vector<size_t> source_blenders; const SourceBlenderBlueprint * reference = nullptr; for (size_t i = begin_idx; i < self.childCnt(); ++i) { @@ -63,7 +63,7 @@ void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx, boo top->addChild(std::move(sources.back())); sources.pop_back(); } - blender_up = Blueprint::optimize(std::move(blender_up), sort_by_cost); + blender_up = Blueprint::optimize(std::move(blender_up)); self.addChild(std::move(blender_up)); } } @@ -87,15 +87,21 @@ need_normal_features_for_children(const IntermediateBlueprint &blueprint, fef::M //----------------------------------------------------------------------------- double +AndNotBlueprint::calculate_relative_estimate() const +{ + return AndNotFlow::estimate_of(get_children()); +} + +double AndNotBlueprint::calculate_cost() const { - return AndNotFlow::cost_of(get_children()); + return AndNotFlow::cost_of(get_children(), false); } double -AndNotBlueprint::calculate_relative_estimate() const +AndNotBlueprint::calculate_strict_cost() const { - return AndNotFlow::estimate_of(get_children()); + return AndNotFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -114,7 +120,7 @@ AndNotBlueprint::exposeFields() const } void -AndNotBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +AndNotBlueprint::optimize_self(OptimizePass pass) { if (childCnt() == 0) { return; @@ -152,7 +158,7 @@ AndNotBlueprint::optimize_self(OptimizePass pass, 
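
The flow sort functions above now take a strict flag, and the removed DefaultAdapter comment describes how strict cost is aggregated: AND substitutes the strict cost of its first child, OR aggregates strict child costs. A simplified reading of that model, assuming AND evaluates children left to right over a shrinking fraction of documents and OR over the fraction not yet matched; ChildStats and both functions are illustrative, not the actual flow.h code:

#include <cstddef>
#include <vector>

struct ChildStats { double estimate; double cost; double strict_cost; };

// AND: each child only sees the fraction of documents that survived the
// children before it; when strict, the first child uses its strict cost.
double and_flow_cost(const std::vector<ChildStats>& children, bool strict)
{
    double flow = 1.0;
    double total = 0.0;
    for (size_t i = 0; i < children.size(); ++i) {
        total += flow * ((strict && i == 0) ? children[i].strict_cost : children[i].cost);
        flow *= children[i].estimate;
    }
    return total;
}

// OR: each child only sees the fraction of documents not matched by the
// children before it; when strict, every child is evaluated strictly.
double or_flow_cost(const std::vector<ChildStats>& children, bool strict)
{
    double flow = 1.0;
    double total = 0.0;
    for (const auto& child : children) {
        total += flow * (strict ? child.strict_cost : child.cost);
        flow *= (1.0 - child.estimate);
    }
    return total;
}
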
bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 1, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 1); } } @@ -166,10 +172,10 @@ AndNotBlueprint::get_replacement() } void -AndNotBlueprint::sort(Children &children, bool sort_by_cost) const +AndNotBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndNotFlow::sort(children); + AndNotFlow::sort(children, strict); } else { if (children.size() > 2) { std::sort(children.begin() + 1, children.end(), TieredGreaterEstimate()); @@ -213,13 +219,18 @@ AndNotBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) co //----------------------------------------------------------------------------- double +AndBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double AndBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()); + return AndFlow::cost_of(get_children(), false); } double -AndBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +AndBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -235,7 +246,7 @@ AndBlueprint::exposeFields() const } void -AndBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +AndBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 0; i < childCnt(); ++i) { @@ -248,7 +259,7 @@ AndBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<AndBlueprint>(*this, 0, sort_by_cost); + optimize_source_blenders<AndBlueprint>(*this, 0); } } @@ -262,10 +273,10 @@ AndBlueprint::get_replacement() } void -AndBlueprint::sort(Children &children, bool sort_by_cost) const +AndBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndFlow::sort(children); + AndFlow::sort(children, strict); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -322,13 +333,18 @@ OrBlueprint::computeNextHitRate(const Blueprint & child, double hit_rate) const OrBlueprint::~OrBlueprint() = default; double +OrBlueprint::calculate_relative_estimate() const { + return OrFlow::estimate_of(get_children()); +} + +double OrBlueprint::calculate_cost() const { - return OrFlow::cost_of(get_children()); + return OrFlow::cost_of(get_children(), false); } double -OrBlueprint::calculate_relative_estimate() const { - return OrFlow::estimate_of(get_children()); +OrBlueprint::calculate_strict_cost() const { + return OrFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -344,7 +360,7 @@ OrBlueprint::exposeFields() const } void -OrBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +OrBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 0; (childCnt() > 1) && (i < childCnt()); ++i) { @@ -359,7 +375,7 @@ OrBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 0, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 0); } } @@ -373,10 +389,10 @@ OrBlueprint::get_replacement() } void -OrBlueprint::sort(Children &children, bool sort_by_cost) const +OrBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - OrFlow::sort(children); + OrFlow::sort(children, strict); } else { 
std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -427,17 +443,22 @@ OrBlueprint::calculate_cost_tier() const WeakAndBlueprint::~WeakAndBlueprint() = default; double -WeakAndBlueprint::calculate_cost() const { - return OrFlow::cost_of(get_children()); -} - -double WeakAndBlueprint::calculate_relative_estimate() const { double child_est = OrFlow::estimate_of(get_children()); double my_est = abs_to_rel_est(_n, get_docid_limit()); return std::min(my_est, child_est); } +double +WeakAndBlueprint::calculate_cost() const { + return OrFlow::cost_of(get_children(), false); +} + +double +WeakAndBlueprint::calculate_strict_cost() const { + return OrFlow::cost_of(get_children(), true); +} + Blueprint::HitEstimate WeakAndBlueprint::combine(const std::vector<HitEstimate> &data) const { @@ -456,7 +477,7 @@ WeakAndBlueprint::exposeFields() const } void -WeakAndBlueprint::sort(Children &, bool) const +WeakAndBlueprint::sort(Children &, bool, bool) const { // order needs to stay the same as _weights } @@ -498,13 +519,18 @@ WeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) c //----------------------------------------------------------------------------- double +NearBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double NearBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()) + childCnt() * 1.0; + return AndFlow::cost_of(get_children(), false) + childCnt() * estimate(); } double -NearBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +NearBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true) + childCnt() * estimate(); } Blueprint::HitEstimate @@ -520,10 +546,10 @@ NearBlueprint::exposeFields() const } void -NearBlueprint::sort(Children &children, bool sort_by_cost) const +NearBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndFlow::sort(children); + AndFlow::sort(children, strict); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -565,13 +591,18 @@ NearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons //----------------------------------------------------------------------------- double +ONearBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double ONearBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()) + (childCnt() * 1.0); + return AndFlow::cost_of(get_children(), false) + childCnt() * estimate(); } double -ONearBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +ONearBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true) + childCnt() * estimate(); } Blueprint::HitEstimate @@ -587,7 +618,7 @@ ONearBlueprint::exposeFields() const } void -ONearBlueprint::sort(Children &, bool) const +ONearBlueprint::sort(Children &, bool, bool) const { // ordered near cannot sort children here } @@ -630,13 +661,18 @@ ONearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) con //----------------------------------------------------------------------------- double +RankBlueprint::calculate_relative_estimate() const { + return (childCnt() == 0) ? 0.0 : getChild(0).estimate(); +} + +double RankBlueprint::calculate_cost() const { - return (childCnt() == 0) ? 1.0 : getChild(0).cost(); + return (childCnt() == 0) ? 
0.0 : getChild(0).cost(); } double -RankBlueprint::calculate_relative_estimate() const { - return (childCnt() == 0) ? 0.0 : getChild(0).estimate(); +RankBlueprint::calculate_strict_cost() const { + return (childCnt() == 0) ? 0.0 : getChild(0).strict_cost(); } Blueprint::HitEstimate @@ -655,7 +691,7 @@ RankBlueprint::exposeFields() const } void -RankBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +RankBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 1; i < childCnt(); ++i) { @@ -665,7 +701,7 @@ RankBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 1, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 1); } } @@ -679,7 +715,7 @@ RankBlueprint::get_replacement() } void -RankBlueprint::sort(Children &, bool) const +RankBlueprint::sort(Children &, bool, bool) const { } @@ -731,8 +767,13 @@ SourceBlenderBlueprint::SourceBlenderBlueprint(const ISourceSelector &selector) SourceBlenderBlueprint::~SourceBlenderBlueprint() = default; double +SourceBlenderBlueprint::calculate_relative_estimate() const { + return OrFlow::estimate_of(get_children()); +} + +double SourceBlenderBlueprint::calculate_cost() const { - double my_cost = 1.0; + double my_cost = 0.0; for (const auto &child: get_children()) { my_cost = std::max(my_cost, child->cost()); } @@ -740,8 +781,12 @@ SourceBlenderBlueprint::calculate_cost() const { } double -SourceBlenderBlueprint::calculate_relative_estimate() const { - return OrFlow::estimate_of(get_children()); +SourceBlenderBlueprint::calculate_strict_cost() const { + double my_cost = 0.0; + for (const auto &child: get_children()) { + my_cost = std::max(my_cost, child->strict_cost()); + } + return my_cost; } Blueprint::HitEstimate @@ -757,7 +802,7 @@ SourceBlenderBlueprint::exposeFields() const } void -SourceBlenderBlueprint::sort(Children &, bool) const +SourceBlenderBlueprint::sort(Children &, bool, bool) const { } diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h index 620280e979b..368cbd35c69 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h @@ -15,14 +15,15 @@ class AndNotBlueprint : public IntermediateBlueprint { public: bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; AndNotBlueprint * asAndNot() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -43,14 +44,15 @@ class AndBlueprint : public IntermediateBlueprint { public: bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() 
const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; AndBlueprint * asAnd() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -69,14 +71,15 @@ class OrBlueprint : public IntermediateBlueprint public: ~OrBlueprint() override; bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; OrBlueprint * asOr() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -97,11 +100,12 @@ private: std::vector<uint32_t> _weights; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_on_cost) const override; bool inheritStrict(size_t i) const override; bool always_needs_unpack() const override; WeakAndBlueprint * asWeakAnd() noexcept final { return this; } @@ -128,12 +132,12 @@ private: uint32_t _window; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - bool should_optimize_children() const override { return false; } - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; SearchIterator::UP @@ -152,12 +156,12 @@ private: uint32_t _window; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - bool should_optimize_children() const override { return false; } - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const 
override; bool inheritStrict(size_t i) const override; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; SearchIterator::UP @@ -173,13 +177,14 @@ public: class RankBlueprint final : public IntermediateBlueprint { public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; bool isRank() const noexcept final { return true; } SearchIterator::UP @@ -202,11 +207,12 @@ private: public: explicit SourceBlenderBlueprint(const ISourceSelector &selector) noexcept; ~SourceBlenderBlueprint() override; - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, diff --git a/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp b/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp index 1f83075b9fc..8fc7733f279 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp @@ -3,7 +3,7 @@ #include <vespa/vespalib/objects/visit.h> #include <vespa/vespalib/util/priority_queue.h> #include <limits> -#include <set> +#include <map> #include <vespa/log/log.h> LOG_SETUP(".nearsearch"); @@ -16,13 +16,15 @@ using search::fef::TermFieldMatchDataArray; using search::fef::TermFieldMatchDataPositionKey; template<typename T> -void setup_fields(uint32_t window, std::vector<T> &matchers, const TermFieldMatchDataArray &in) { - std::set<uint32_t> fields; +void setup_fields(uint32_t window, std::vector<T> &matchers, const TermFieldMatchDataArray &in, uint32_t terms) { + std::map<uint32_t,uint32_t> fields; for (size_t i = 0; i < in.size(); ++i) { - fields.insert(in[i]->getFieldId()); + ++fields[in[i]->getFieldId()]; } - for (const auto& elem : fields) { - matchers.push_back(T(window, elem, in)); + for (auto [field, cnt]: fields) { + if (cnt == terms) { + matchers.push_back(T(window, field, in)); + } } } @@ -126,7 +128,7 @@ NearSearch::NearSearch(Children terms, : NearSearchBase(std::move(terms), data, window, strict), _matchers() { - setup_fields(window, _matchers, data); + setup_fields(window, _matchers, data, getChildren().size()); } namespace { @@ -227,7 +229,7 @@ ONearSearch::ONearSearch(Children terms, : NearSearchBase(std::move(terms), data, window, strict), _matchers() { - setup_fields(window, _matchers, data); + setup_fields(window, _matchers, data, getChildren().size()); } bool diff --git a/searchlib/src/vespa/searchlib/queryeval/orlikesearch.h b/searchlib/src/vespa/searchlib/queryeval/orlikesearch.h index 9c67d2c6a01..a15a87c2d03 
100644
--- a/searchlib/src/vespa/searchlib/queryeval/orlikesearch.h
+++ b/searchlib/src/vespa/searchlib/queryeval/orlikesearch.h
@@ -66,5 +66,83 @@ private:
     Unpack _unpacker;
 };
 
+template <typename Unpack, typename HEAP, typename ref_t>
+class StrictHeapOrSearch final : public OrSearch
+{
+private:
+    struct Less {
+        const uint32_t *child_docid;
+        constexpr explicit Less(const std::vector<uint32_t> &cd) noexcept : child_docid(cd.data()) {}
+        constexpr bool operator()(const ref_t &a, const ref_t &b) const noexcept {
+            return (child_docid[a] < child_docid[b]);
+        }
+    };
+
+    std::vector<ref_t> _data;
+    std::vector<uint32_t> _child_docid;
+    Unpack _unpacker;
+
+    void init_data() {
+        _data.resize(getChildren().size());
+        for (size_t i = 0; i < getChildren().size(); ++i) {
+            _data[i] = i;
+        }
+    }
+    void onRemove(size_t index) override {
+        _unpacker.onRemove(index);
+        _child_docid.erase(_child_docid.begin() + index);
+        init_data();
+    }
+    void onInsert(size_t index) override {
+        _unpacker.onInsert(index);
+        _child_docid.insert(_child_docid.begin() + index, getChildren()[index]->getDocId());
+        init_data();
+    }
+    void seek_child(ref_t child, uint32_t docid) {
+        getChildren()[child]->doSeek(docid);
+        _child_docid[child] = getChildren()[child]->getDocId();
+    }
+    ref_t *data_begin() noexcept { return _data.data(); }
+    ref_t *data_pos(size_t offset) noexcept { return _data.data() + offset; }
+    ref_t *data_end() noexcept { return _data.data() + _data.size(); }
+
+public:
+    StrictHeapOrSearch(Children children, const Unpack &unpacker)
+      : OrSearch(std::move(children)),
+        _data(),
+        _child_docid(getChildren().size()),
+        _unpacker(unpacker)
+    {
+        HEAP::require_left_heap();
+        init_data();
+    }
+    void initRange(uint32_t begin, uint32_t end) override {
+        OrSearch::initRange(begin, end);
+        for (size_t i = 0; i < getChildren().size(); ++i) {
+            _child_docid[i] = getChildren()[i]->getDocId();
+        }
+        for (size_t i = 2; i <= _data.size(); ++i) {
+            HEAP::push(data_begin(), data_pos(i), Less(_child_docid));
+        }
+    }
+    void doSeek(uint32_t docid) override {
+        while (_child_docid[HEAP::front(data_begin(), data_end())] < docid) {
+            seek_child(HEAP::front(data_begin(), data_end()), docid);
+            HEAP::adjust(data_begin(), data_end(), Less(_child_docid));
+        }
+        setDocId(_child_docid[HEAP::front(data_begin(), data_end())]);
+    }
+    void doUnpack(uint32_t docid) override {
+        _unpacker.each([&](ref_t child) {
+            if (__builtin_expect(_child_docid[child] == docid, false)) {
+                getChildren()[child]->doUnpack(docid);
+            }
+        }, getChildren().size());
+    }
+    bool needUnpack(size_t index) const override {
+        return _unpacker.needUnpack(index);
+    }
+    Trinary is_strict() const override { return Trinary::True; }
+};
 
 }
diff --git a/searchlib/src/vespa/searchlib/queryeval/orsearch.cpp b/searchlib/src/vespa/searchlib/queryeval/orsearch.cpp
index 7cdb13fc159..29ec8632612 100644
--- a/searchlib/src/vespa/searchlib/queryeval/orsearch.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/orsearch.cpp
@@ -4,11 +4,15 @@
 #include "orlikesearch.h"
 #include "termwise_helper.h"
 #include <vespa/searchlib/common/bitvector.h>
+#include <vespa/vespalib/util/left_right_heap.h>
 
 namespace search::queryeval {
 
 namespace {
 
+using vespalib::LeftArrayHeap;
+using vespalib::LeftHeap;
+
 class FullUnpack
 {
 public:
@@ -24,6 +28,11 @@ public:
             }
         }
     }
+    void each(auto &&f, size_t n) {
+        for (size_t i = 0; i < n; ++i) {
+            f(i);
+        }
+    }
     void onRemove(size_t index) { (void) index; }
     void onInsert(size_t index) { (void) index; }
     bool needUnpack(size_t index) const { (void) index; return
true; }
@@ -47,6 +56,9 @@ public:
             }
         }, children.size());
     }
+    void each(auto &&f, size_t n) {
+        _unpackInfo.each(std::forward<decltype(f)>(f), n);
+    }
     void onRemove(size_t index) {
         _unpackInfo.remove(index);
     }
@@ -60,6 +72,22 @@ private:
     UnpackInfo _unpackInfo;
 };
 
+template <typename Unpack>
+SearchIterator::UP create_strict_or(std::vector<SearchIterator::UP> children, const Unpack &unpack, OrSearch::StrictImpl strict_impl) {
+    if (strict_impl == OrSearch::StrictImpl::HEAP) {
+        if (children.size() <= 0x70) {
+            return std::make_unique<StrictHeapOrSearch<Unpack,LeftArrayHeap,uint8_t>>(std::move(children), unpack);
+        } else if (children.size() <= 0xff) {
+            return std::make_unique<StrictHeapOrSearch<Unpack,LeftHeap,uint8_t>>(std::move(children), unpack);
+        } else if (children.size() <= 0xffff) {
+            return std::make_unique<StrictHeapOrSearch<Unpack,LeftHeap,uint16_t>>(std::move(children), unpack);
+        } else {
+            return std::make_unique<StrictHeapOrSearch<Unpack,LeftHeap,uint32_t>>(std::move(children), unpack);
+        }
+    }
+    return std::make_unique<OrLikeSearch<true,Unpack>>(std::move(children), unpack);
+}
+
 }
 
 BitVector::UP
@@ -82,21 +110,23 @@ SearchIterator::UP
 OrSearch::create(ChildrenIterators children, bool strict) {
     UnpackInfo unpackInfo;
     unpackInfo.forceAll();
-    return create(std::move(children), strict, unpackInfo);
+    return create(std::move(children), strict, unpackInfo, StrictImpl::PLAIN);
 }
 
 SearchIterator::UP
 OrSearch::create(ChildrenIterators children, bool strict, const UnpackInfo & unpackInfo) {
+    return create(std::move(children), strict, unpackInfo, StrictImpl::PLAIN);
+}
+
+SearchIterator::UP
+OrSearch::create(ChildrenIterators children, bool strict, const UnpackInfo & unpackInfo, StrictImpl strict_impl) {
     if (strict) {
         if (unpackInfo.unpackAll()) {
-            using MyOr = OrLikeSearch<true, FullUnpack>;
-            return std::make_unique<MyOr>(std::move(children), FullUnpack());
+            return create_strict_or(std::move(children), FullUnpack(), strict_impl);
         } else if(unpackInfo.empty()) {
-            using MyOr = OrLikeSearch<true, NoUnpack>;
-            return std::make_unique<MyOr>(std::move(children), NoUnpack());
+            return create_strict_or(std::move(children), NoUnpack(), strict_impl);
         } else {
-            using MyOr = OrLikeSearch<true, SelectiveUnpack>;
-            return std::make_unique<MyOr>(std::move(children), SelectiveUnpack(unpackInfo));
+            return create_strict_or(std::move(children), SelectiveUnpack(unpackInfo), strict_impl);
         }
     } else {
         if (unpackInfo.unpackAll()) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/orsearch.h b/searchlib/src/vespa/searchlib/queryeval/orsearch.h
index d56fb0e99b4..02ee80f1bd8 100644
--- a/searchlib/src/vespa/searchlib/queryeval/orsearch.h
+++ b/searchlib/src/vespa/searchlib/queryeval/orsearch.h
@@ -15,8 +15,12 @@ class OrSearch : public MultiSearch
 public:
     using Children = MultiSearch::Children;
 
+    enum class StrictImpl { PLAIN, HEAP };
+
     static SearchIterator::UP create(ChildrenIterators children, bool strict);
     static SearchIterator::UP create(ChildrenIterators children, bool strict, const UnpackInfo & unpackInfo);
+    static SearchIterator::UP create(ChildrenIterators children, bool strict, const UnpackInfo & unpackInfo,
+                                     StrictImpl strict_impl);
 
     std::unique_ptr<BitVector> get_hits(uint32_t begin_id) override;
     void or_hits_into(BitVector &result, uint32_t begin_id) override;
diff --git a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp
index 500e9fe4dbb..96181377282 100644
---
a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp @@ -45,7 +45,7 @@ SameElementBlueprint::addTerm(Blueprint::UP term) } void -SameElementBlueprint::optimize_self(OptimizePass pass, bool) +SameElementBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::LAST) { std::sort(_terms.begin(), _terms.end(), diff --git a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h index 06c20339e81..6a988e67149 100644 --- a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h @@ -34,7 +34,7 @@ public: // used by create visitor void addTerm(Blueprint::UP term); - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; void fetchPostings(const ExecuteInfo &execInfo) override; std::unique_ptr<SameElementSearch> create_same_element_search(search::fef::TermFieldMatchData& tfmd, bool strict) const; diff --git a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.cpp b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.cpp index 0bbdf89bab7..2b25aa29747 100644 --- a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.cpp @@ -8,7 +8,6 @@ using search::fef::TermFieldMatchData; using std::unique_ptr; -using std::transform; using std::vector; using vespalib::ObjectVisitor; diff --git a/searchlib/src/vespa/searchlib/queryeval/unpackinfo.h b/searchlib/src/vespa/searchlib/queryeval/unpackinfo.h index 0ec8d07e19e..64efd66b8c0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/unpackinfo.h +++ b/searchlib/src/vespa/searchlib/queryeval/unpackinfo.h @@ -59,6 +59,7 @@ struct NoUnpack { (void) docid; (void) search; } + void each(auto &&f, size_t n) { (void) f; (void) n; } void onRemove(size_t index) { (void) index; } void onInsert(size_t index) { (void) index; } bool needUnpack(size_t index) const { (void) index; return false; } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index 0929f80a8f0..0be014474d0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "weighted_set_term_search.h" +#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/attribute/i_direct_posting_store.h> +#include <vespa/searchlib/attribute/multi_term_hash_filter.hpp> #include <vespa/searchlib/common/bitvector.h> -#include <vespa/searchlib/attribute/multi_term_or_filter_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/vespalib/stllike/hash_map.hpp> #include "iterator_pack.h" #include "blueprint.h" @@ -14,7 +17,13 @@ using vespalib::ObjectVisitor; namespace search::queryeval { -template <typename HEAP, typename IteratorPack> +enum class UnpackType { + DocidAndWeights, + Docid, + None +}; + +template <UnpackType unpack_type, typename HEAP, typename IteratorPack> class WeightedSetTermSearchImpl : public WeightedSetTermSearch { private: @@ -47,7 +56,6 @@ private: ref_t *_data_stash; ref_t *_data_end; IteratorPack _children; - bool _need_match_data; void seek_child(ref_t child, uint32_t docId) { _termPos[child] = _children.seek(child, docId); @@ -64,7 +72,6 @@ private: public: WeightedSetTermSearchImpl(fef::TermFieldMatchData &tmd, - bool field_is_filter, std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, IteratorPack &&iteratorPack) : _tmd(tmd), @@ -77,8 +84,7 @@ public: _data_begin(nullptr), _data_stash(nullptr), _data_end(nullptr), - _children(std::move(iteratorPack)), - _need_match_data(!field_is_filter && !_tmd.isNotNeeded()) + _children(std::move(iteratorPack)) { HEAP::require_left_heap(); assert(_children.size() > 0); @@ -89,7 +95,7 @@ public: } _data_begin = &_data_space[0]; _data_end = _data_begin + _data_space.size(); - if (_need_match_data) { + if constexpr (unpack_type == UnpackType::DocidAndWeights) { _tmd.reservePositions(_children.size()); } } @@ -115,7 +121,7 @@ public: } void doUnpack(uint32_t docId) override { - if (_need_match_data) { + if constexpr (unpack_type == UnpackType::DocidAndWeights) { _tmd.reset(docId); pop_matching_children(docId); std::sort(_data_stash, _data_end, _cmpWeight); @@ -124,7 +130,7 @@ public: pos.setElementWeight(_weights[*ptr]); _tmd.appendPosition(pos); } - } else { + } else if constexpr (unpack_type == UnpackType::Docid) { _tmd.resetOnlyDocId(docId); } } @@ -162,68 +168,170 @@ public: } }; -//----------------------------------------------------------------------------- +template <typename HeapType, typename IteratorPackType> +SearchIterator::UP +create_helper(fef::TermFieldMatchData& tmd, + bool is_filter_search, + std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, + IteratorPackType&& pack) +{ + bool match_data_needed = !tmd.isNotNeeded(); + if (is_filter_search && match_data_needed) { + return std::make_unique<WeightedSetTermSearchImpl<UnpackType::Docid, HeapType, IteratorPackType>> + (tmd, std::move(weights), std::move(pack)); + } else if (!is_filter_search && match_data_needed) { + return std::make_unique<WeightedSetTermSearchImpl<UnpackType::DocidAndWeights, HeapType, IteratorPackType>> + (tmd, std::move(weights), std::move(pack)); + } else { + return std::make_unique<WeightedSetTermSearchImpl<UnpackType::None, HeapType, IteratorPackType>> + (tmd, std::move(weights), std::move(pack)); + } +} SearchIterator::UP WeightedSetTermSearch::create(const std::vector<SearchIterator *> &children, TermFieldMatchData &tmd, - bool field_is_filter, + bool is_filter_search, 
const std::vector<int32_t> &weights, fef::MatchData::UP match_data) { - using ArrayHeapImpl = WeightedSetTermSearchImpl<vespalib::LeftArrayHeap, SearchIteratorPack>; - using HeapImpl = WeightedSetTermSearchImpl<vespalib::LeftHeap, SearchIteratorPack>; - - if (tmd.isNotNeeded()) { - return attribute::MultiTermOrFilterSearch::create(children, std::move(match_data)); - } - if (children.size() < 128) { - return SearchIterator::UP(new ArrayHeapImpl(tmd, field_is_filter, std::cref(weights), SearchIteratorPack(children, std::move(match_data)))); + return create_helper<vespalib::LeftArrayHeap, SearchIteratorPack>(tmd, is_filter_search, std::cref(weights), + SearchIteratorPack(children, std::move(match_data))); } - return SearchIterator::UP(new HeapImpl(tmd, field_is_filter, std::cref(weights), SearchIteratorPack(children, std::move(match_data)))); + return create_helper<vespalib::LeftHeap, SearchIteratorPack>(tmd, is_filter_search, std::cref(weights), + SearchIteratorPack(children, std::move(match_data))); } -//----------------------------------------------------------------------------- - namespace { template <typename IteratorType, typename IteratorPackType> SearchIterator::UP -create_helper(fef::TermFieldMatchData& tmd, - bool field_is_filter, - std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, - std::vector<IteratorType>&& iterators) +create_helper_resolve_pack(fef::TermFieldMatchData& tmd, + bool is_filter_search, + std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, + std::vector<IteratorType>&& iterators) { - using ArrayHeapImpl = WeightedSetTermSearchImpl<vespalib::LeftArrayHeap, IteratorPackType>; - using HeapImpl = WeightedSetTermSearchImpl<vespalib::LeftHeap, IteratorPackType>; - if (iterators.size() < 128) { - return SearchIterator::UP(new ArrayHeapImpl(tmd, field_is_filter, std::move(weights), IteratorPackType(std::move(iterators)))); + return create_helper<vespalib::LeftArrayHeap, IteratorPackType>(tmd, is_filter_search, std::move(weights), + IteratorPackType(std::move(iterators))); } - return SearchIterator::UP(new HeapImpl(tmd, field_is_filter, std::move(weights), IteratorPackType(std::move(iterators)))); + return create_helper<vespalib::LeftHeap, IteratorPackType>(tmd, is_filter_search, std::move(weights), + IteratorPackType(std::move(iterators))); } } SearchIterator::UP WeightedSetTermSearch::create(fef::TermFieldMatchData& tmd, - bool field_is_filter, + bool is_filter_search, std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, std::vector<DocidIterator>&& iterators) { - return create_helper<DocidIterator, DocidIteratorPack>(tmd, field_is_filter, std::move(weights), std::move(iterators)); + return create_helper_resolve_pack<DocidIterator, DocidIteratorPack>(tmd, is_filter_search, std::move(weights), std::move(iterators)); } SearchIterator::UP WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, - bool field_is_filter, + bool is_filter_search, std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, std::vector<DocidWithWeightIterator> &&iterators) { - return create_helper<DocidWithWeightIterator, DocidWithWeightIteratorPack>(tmd, field_is_filter, std::move(weights), std::move(iterators)); + return create_helper_resolve_pack<DocidWithWeightIterator, DocidWithWeightIteratorPack>(tmd, is_filter_search, std::move(weights), std::move(iterators)); } 
-//----------------------------------------------------------------------------- +namespace { + +class HashFilterWrapper { +protected: + const attribute::IAttributeVector& _attr; +public: + HashFilterWrapper(const attribute::IAttributeVector& attr) : _attr(attr) {} +}; + +template <bool unpack_weights_t> +class StringHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::EnumHandle; + static constexpr bool unpack_weights = unpack_weights_t; + StringHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, const IDirectPostingStore& store, vespalib::datastore::EntryRef dict_snapshot) const { + std::vector<TokenT> result; + store.collect_folded(term.enum_idx, dict_snapshot, [&](vespalib::datastore::EntryRef ref) { result.emplace_back(ref.ref()); }); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getEnum(docid); + } +}; + +template <bool unpack_weights_t> +class IntegerHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::largeint_t; + static constexpr bool unpack_weights = unpack_weights_t; + IntegerHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, + const IDirectPostingStore& store, + vespalib::datastore::EntryRef) const { + std::vector<TokenT> result; + result.emplace_back(store.get_integer_value(term.enum_idx)); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getInt(docid); + } +}; + +template <typename WrapperType> +SearchIterator::UP +create_hash_filter_helper(fef::TermFieldMatchData& tfmd, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + using FilterType = attribute::MultiTermHashFilter<WrapperType>; + typename FilterType::TokenMap tokens; + WrapperType wrapper(attr); + for (size_t i = 0; i < terms.size(); ++i) { + for (auto token : wrapper.mapToken(terms[i], posting_store, dict_snapshot)) { + tokens[token] = weights[i]; + } + } + return std::make_unique<FilterType>(tfmd, wrapper, std::move(tokens)); +} + +} + +SearchIterator::UP +WeightedSetTermSearch::create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + if (attr.isStringType()) { + if (is_filter_search) { + return create_hash_filter_helper<StringHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<StringHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } else { + assert(attr.isIntegerType()); + if (is_filter_search) { + return create_hash_filter_helper<IntegerHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<IntegerHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } +} } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h 
b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h
index a9ab86e2c5f..d078fd5babc 100644
--- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h
+++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h
@@ -11,6 +11,8 @@
 #include <variant>
 #include <vector>
 
+namespace search::attribute { class IAttributeVector; }
+
 namespace search::fef { class TermFieldMatchData; }
 
 namespace search::queryeval {
@@ -18,8 +20,7 @@ namespace search::queryeval {
 class Blueprint;
 
 /**
- * Search iterator for a weighted set, based on a set of child search
- * iterators.
+ * Search iterator for a WeightedSetTerm, based on a set of child search iterators.
  */
 class WeightedSetTermSearch : public SearchIterator
 {
@@ -27,26 +28,38 @@ protected:
     WeightedSetTermSearch() = default;
 
 public:
+    // Whether this iterator is considered a filter, independent of attribute vector settings (ref rank: filter).
     static constexpr bool filter_search = false;
+    // Whether this iterator requires btree iterators for all tokens/terms used by the operator.
     static constexpr bool require_btree_iterators = false;
+    // Whether this supports creating a hash filter iterator.
+    static constexpr bool supports_hash_filter = true;
 
     // TODO: pass ownership with unique_ptr
     static SearchIterator::UP create(const std::vector<SearchIterator *> &children,
                                      search::fef::TermFieldMatchData &tmd,
-                                     bool field_is_filter,
+                                     bool is_filter_search,
                                      const std::vector<int32_t> &weights,
                                      fef::MatchData::UP match_data);
 
     static SearchIterator::UP create(search::fef::TermFieldMatchData& tmd,
-                                     bool field_is_filter,
+                                     bool is_filter_search,
                                      std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights,
                                      std::vector<DocidIterator>&& iterators);
 
     static SearchIterator::UP create(search::fef::TermFieldMatchData &tmd,
-                                     bool field_is_filter,
+                                     bool is_filter_search,
                                      std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights,
                                      std::vector<DocidWithWeightIterator> &&iterators);
 
+    static SearchIterator::UP create_hash_filter(search::fef::TermFieldMatchData& tmd,
+                                                 bool is_filter_search,
+                                                 const std::vector<int32_t>& weights,
+                                                 const std::vector<IDirectPostingStore::LookupResult>& terms,
+                                                 const attribute::IAttributeVector& attr,
+                                                 const IDirectPostingStore& posting_store,
+                                                 vespalib::datastore::EntryRef dictionary_snapshot);
+
     // used during docsum fetching to identify matching elements
     // initRange must be called before use.
     // doSeek/doUnpack must not be called.
diff --git a/slobrok/src/vespa/slobrok/server/rpchooks.cpp b/slobrok/src/vespa/slobrok/server/rpchooks.cpp index 851a556b15e..032b6a9af8f 100644 --- a/slobrok/src/vespa/slobrok/server/rpchooks.cpp +++ b/slobrok/src/vespa/slobrok/server/rpchooks.cpp @@ -77,14 +77,6 @@ RPCHooks::RPCHooks(SBEnv &env) RPCHooks::~RPCHooks() = default; void RPCHooks::reportMetrics() { - EV_COUNT("heartbeats_failed", _cnts.heartBeatFails); - EV_COUNT("register_reqs", _cnts.registerReqs); - EV_COUNT("mirror_reqs", _cnts.mirrorReqs); - EV_COUNT("wantadd_reqs", _cnts.wantAddReqs); - EV_COUNT("doadd_reqs", _cnts.doAddReqs); - EV_COUNT("doremove_reqs", _cnts.doRemoveReqs); - EV_COUNT("admin_reqs", _cnts.adminReqs); - EV_COUNT("other_reqs", _cnts.otherReqs); } void RPCHooks::initRPC(FRT_Supervisor *supervisor) { diff --git a/storage/src/tests/bucketdb/bucketmanagertest.cpp b/storage/src/tests/bucketdb/bucketmanagertest.cpp index 91e901c7254..45d8fab7061 100644 --- a/storage/src/tests/bucketdb/bucketmanagertest.cpp +++ b/storage/src/tests/bucketdb/bucketmanagertest.cpp @@ -453,7 +453,8 @@ TEST_F(BucketManagerTest, metrics_are_tracked_per_bucket_space) { auto& repo = _node->getComponentRegister().getBucketSpaceRepo(); { bucketdb::StorageBucketInfo entry; - api::BucketInfo info(50, 100, 200); + // checksum, doc count, doc size, meta count, total bucket size (incl meta) + api::BucketInfo info(50, 100, 200, 101, 211); info.setReady(true); entry.setBucketInfo(info); repo.get(document::FixedBucketSpaces::default_space()).bucketDatabase() @@ -461,7 +462,7 @@ TEST_F(BucketManagerTest, metrics_are_tracked_per_bucket_space) { } { bucketdb::StorageBucketInfo entry; - api::BucketInfo info(60, 150, 300); + api::BucketInfo info(60, 150, 300, 153, 307); info.setActive(true); entry.setBucketInfo(info); repo.get(document::FixedBucketSpaces::global_space()).bucketDatabase() @@ -475,6 +476,7 @@ TEST_F(BucketManagerTest, metrics_are_tracked_per_bucket_space) { auto default_m = spaces.find(document::FixedBucketSpaces::default_space()); ASSERT_TRUE(default_m != spaces.end()); EXPECT_EQ(1, default_m->second->buckets_total.getLast()); + EXPECT_EQ(101, default_m->second->entries.getLast()); EXPECT_EQ(100, default_m->second->docs.getLast()); EXPECT_EQ(200, default_m->second->bytes.getLast()); EXPECT_EQ(0, default_m->second->active_buckets.getLast()); @@ -485,6 +487,7 @@ TEST_F(BucketManagerTest, metrics_are_tracked_per_bucket_space) { auto global_m = spaces.find(document::FixedBucketSpaces::global_space()); ASSERT_TRUE(global_m != spaces.end()); EXPECT_EQ(1, global_m->second->buckets_total.getLast()); + EXPECT_EQ(153, global_m->second->entries.getLast()); EXPECT_EQ(150, global_m->second->docs.getLast()); EXPECT_EQ(300, global_m->second->bytes.getLast()); EXPECT_EQ(1, global_m->second->active_buckets.getLast()); @@ -499,7 +502,11 @@ TEST_F(BucketManagerTest, metrics_are_tracked_per_bucket_space) { jsonStream << End(); EXPECT_EQ(std::string("{\"values\":[" "{\"name\":\"vds.datastored.bucket_space.buckets_total\",\"values\":{\"last\":1},\"dimensions\":{\"bucketSpace\":\"global\"}}," + "{\"name\":\"vds.datastored.bucket_space.entries\",\"values\":{\"last\":153},\"dimensions\":{\"bucketSpace\":\"global\"}}," + "{\"name\":\"vds.datastored.bucket_space.docs\",\"values\":{\"last\":150},\"dimensions\":{\"bucketSpace\":\"global\"}}," "{\"name\":\"vds.datastored.bucket_space.buckets_total\",\"values\":{\"last\":1},\"dimensions\":{\"bucketSpace\":\"default\"}}," + 
"{\"name\":\"vds.datastored.bucket_space.entries\",\"values\":{\"last\":101},\"dimensions\":{\"bucketSpace\":\"default\"}}," + "{\"name\":\"vds.datastored.bucket_space.docs\",\"values\":{\"last\":100},\"dimensions\":{\"bucketSpace\":\"default\"}}," "{\"name\":\"vds.datastored.alldisks.docs\",\"values\":{\"last\":250}}," "{\"name\":\"vds.datastored.alldisks.bytes\",\"values\":{\"last\":500}}," "{\"name\":\"vds.datastored.alldisks.buckets\",\"values\":{\"last\":2}}" diff --git a/storage/src/tests/common/testhelper.cpp b/storage/src/tests/common/testhelper.cpp index 7b8af42fd84..4ca935b7904 100644 --- a/storage/src/tests/common/testhelper.cpp +++ b/storage/src/tests/common/testhelper.cpp @@ -59,15 +59,12 @@ vdstestlib::DirConfig getStandardConfig(bool storagenode, const std::string & ro config = &dc.addConfig("stor-communicationmanager"); config->set("rpcport", "0"); config->set("mbusport", "0"); - config = &dc.addConfig("stor-bucketdb"); - config->set("chunklevel", "0"); config = &dc.addConfig("stor-distributormanager"); config->set("splitcount", "1000"); config->set("splitsize", "10000000"); config->set("joincount", "500"); config->set("joinsize", "5000000"); config->set("max_clock_skew_sec", "0"); - config = &dc.addConfig("stor-opslogger"); config = &dc.addConfig("persistence"); config->set("abort_operations_with_changed_bucket_ownership", "true"); config = &dc.addConfig("stor-filestor"); diff --git a/storage/src/tests/distributor/distributor_stripe_test.cpp b/storage/src/tests/distributor/distributor_stripe_test.cpp index 92cd3898886..a10e4ee6a0e 100644 --- a/storage/src/tests/distributor/distributor_stripe_test.cpp +++ b/storage/src/tests/distributor/distributor_stripe_test.cpp @@ -57,7 +57,7 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { return _stripe->_bucketDBMetricUpdater.getMinimumReplicaCountingMode(); } - std::string testOp(std::shared_ptr<api::StorageMessage> msg) { + std::string testOp(const std::shared_ptr<api::StorageMessage> & msg) { _stripe->handleMessage(msg); std::string tmp = _sender.getCommands(); @@ -83,8 +83,8 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { std::vector<BucketCopy> changedNodes; vespalib::StringTokenizer tokenizer(states[i], ","); - for (uint32_t j = 0; j < tokenizer.size(); ++j) { - vespalib::StringTokenizer tokenizer2(tokenizer[j], ":"); + for (auto token : tokenizer) { + vespalib::StringTokenizer tokenizer2(token, ":"); bool trusted = false; if (tokenizer2.size() > 2) { @@ -96,14 +96,7 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { removedNodes.push_back(node); } else { uint32_t checksum = atoi(tokenizer2[1].data()); - changedNodes.push_back( - BucketCopy( - i + 1, - node, - api::BucketInfo( - checksum, - checksum / 2, - checksum / 4)).setTrusted(trusted)); + changedNodes.emplace_back(i + 1, node, api::BucketInfo(checksum, checksum / 2, checksum / 4)).setTrusted(trusted); } } @@ -112,9 +105,7 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { uint32_t flags(DatabaseUpdate::CREATE_IF_NONEXISTING | (resetTrusted ? 
DatabaseUpdate::RESET_TRUSTED : 0)); - operation_context().update_bucket_database(makeDocumentBucket(document::BucketId(16, 1)), - changedNodes, - flags); + operation_context().update_bucket_database(makeDocumentBucket(document::BucketId(16, 1)), changedNodes, flags); } std::string retVal = dumpBucket(document::BucketId(16, 1)); @@ -122,8 +113,8 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { return retVal; } - void assertBucketSpaceStats(size_t expBucketPending, size_t expBucketTotal, uint16_t node, const vespalib::string& bucketSpace, - const BucketSpacesStatsProvider::PerNodeBucketSpacesStats& stats); + static void assertBucketSpaceStats(size_t expBucketPending, size_t expBucketTotal, uint16_t node, const vespalib::string& bucketSpace, + const BucketSpacesStatsProvider::PerNodeBucketSpacesStats& stats); SimpleMaintenanceScanner::PendingMaintenanceStats stripe_maintenance_stats() { return _stripe->pending_maintenance_stats(); @@ -175,12 +166,6 @@ struct DistributorStripeTest : Test, DistributorStripeTestUtil { }); } - void configure_prioritize_global_bucket_merges(bool enabled) { - configure_stripe_with([&](auto& builder) { - builder.prioritizeGlobalBucketMerges = enabled; - }); - } - void configure_max_activation_inhibited_out_of_sync_groups(uint32_t n_groups) { configure_stripe_with([&](auto& builder) { builder.maxActivationInhibitedOutOfSyncGroups = n_groups; @@ -471,43 +456,6 @@ TEST_F(DistributorStripeTest, update_bucket_database) updateBucketDB("0:456", "2:333", ResetTrusted(true))); } -TEST_F(DistributorStripeTest, priority_config_is_propagated_to_distributor_configuration) -{ - using namespace vespa::config::content::core; - - setup_stripe(Redundancy(2), NodeCount(2), "storage:2 distributor:1"); - - ConfigBuilder builder; - builder.priorityMergeMoveToIdealNode = 1; - builder.priorityMergeOutOfSyncCopies = 2; - builder.priorityMergeTooFewCopies = 3; - builder.priorityActivateNoExistingActive = 4; - builder.priorityActivateWithExistingActive = 5; - builder.priorityDeleteBucketCopy = 6; - builder.priorityJoinBuckets = 7; - builder.prioritySplitDistributionBits = 8; - builder.prioritySplitLargeBucket = 9; - builder.prioritySplitInconsistentBucket = 10; - builder.priorityGarbageCollection = 11; - builder.priorityMergeGlobalBuckets = 12; - - configure_stripe(builder); - - const auto& mp = getConfig().getMaintenancePriorities(); - EXPECT_EQ(1, static_cast<int>(mp.mergeMoveToIdealNode)); - EXPECT_EQ(2, static_cast<int>(mp.mergeOutOfSyncCopies)); - EXPECT_EQ(3, static_cast<int>(mp.mergeTooFewCopies)); - EXPECT_EQ(4, static_cast<int>(mp.activateNoExistingActive)); - EXPECT_EQ(5, static_cast<int>(mp.activateWithExistingActive)); - EXPECT_EQ(6, static_cast<int>(mp.deleteBucketCopy)); - EXPECT_EQ(7, static_cast<int>(mp.joinBuckets)); - EXPECT_EQ(8, static_cast<int>(mp.splitDistributionBits)); - EXPECT_EQ(9, static_cast<int>(mp.splitLargeBucket)); - EXPECT_EQ(10, static_cast<int>(mp.splitInconsistentBucket)); - EXPECT_EQ(11, static_cast<int>(mp.garbageCollection)); - EXPECT_EQ(12, static_cast<int>(mp.mergeGlobalBuckets)); -} - TEST_F(DistributorStripeTest, no_db_resurrection_for_bucket_not_owned_in_pending_state) { setup_stripe(Redundancy(1), NodeCount(10), "storage:2 distributor:2"); // Force new state into being the pending state. 
According to the initial @@ -969,17 +917,6 @@ TEST_F(DistributorStripeTest, weak_internal_read_consistency_config_is_propagate EXPECT_FALSE(getExternalOperationHandler().use_weak_internal_read_consistency_for_gets()); } -TEST_F(DistributorStripeTest, prioritize_global_bucket_merges_config_is_propagated_to_internal_config) -{ - setup_stripe(Redundancy(1), NodeCount(1), "distributor:1 storage:1"); - - configure_prioritize_global_bucket_merges(true); - EXPECT_TRUE(getConfig().prioritize_global_bucket_merges()); - - configure_prioritize_global_bucket_merges(false); - EXPECT_FALSE(getConfig().prioritize_global_bucket_merges()); -} - TEST_F(DistributorStripeTest, max_activation_inhibited_out_of_sync_groups_config_is_propagated_to_internal_config) { setup_stripe(Redundancy(1), NodeCount(1), "distributor:1 storage:1"); diff --git a/storage/src/tests/distributor/idealstatemanagertest.cpp b/storage/src/tests/distributor/idealstatemanagertest.cpp index fbcc188a5da..0cadaa3fc9f 100644 --- a/storage/src/tests/distributor/idealstatemanagertest.cpp +++ b/storage/src/tests/distributor/idealstatemanagertest.cpp @@ -94,33 +94,6 @@ TEST_F(IdealStateManagerTest, status_page) { ost.str()); } -TEST_F(IdealStateManagerTest, disabled_state_checker) { - setup_stripe(1, 1, "distributor:1 storage:1"); - - auto cfg = make_config(); - cfg->setSplitSize(100); - cfg->setSplitCount(1000000); - cfg->disableStateChecker("SplitBucket"); - configure_stripe(cfg); - - insertBucketInfo(document::BucketId(16, 5), 0, 0xff, 100, 200, true, true); - insertBucketInfo(document::BucketId(16, 2), 0, 0xff, 10, 10, true, true); - - std::ostringstream ost; - getIdealStateManager().getBucketStatus(ost); - - EXPECT_EQ(makeBucketStatusString( - "BucketId(0x4000000000000002) : [node(idx=0,crc=0xff,docs=10/10,bytes=10/10,trusted=true,active=true,ready=false)]<br>\n" - "<b>BucketId(0x4000000000000005):</b> <i> : split: [Splitting bucket because its maximum size (200 b, 100 docs, 100 meta, 200 b total) is " - "higher than the configured limit of (100, 1000000)]</i> [node(idx=0,crc=0xff,docs=100/100,bytes=200/200,trusted=true," - "active=true,ready=false)]<br>\n"), - ost.str()); - - tick(); - EXPECT_EQ("", active_ideal_state_operations()); - -} - TEST_F(IdealStateManagerTest, clear_active_on_node_down) { setSystemState(lib::ClusterState("distributor:1 storage:3")); for (int i = 1; i < 4; i++) { diff --git a/storage/src/tests/distributor/statecheckerstest.cpp b/storage/src/tests/distributor/statecheckerstest.cpp index 0f48440b5a1..a0d45292c1d 100644 --- a/storage/src/tests/distributor/statecheckerstest.cpp +++ b/storage/src/tests/distributor/statecheckerstest.cpp @@ -16,7 +16,6 @@ #include <vespa/storageapi/message/stat.h> #include <vespa/vdslib/distribution/distribution.h> #include <vespa/vespalib/gtest/gtest.h> -#include <gmock/gmock.h> using document::test::makeBucketSpace; using document::test::makeDocumentBucket; @@ -175,7 +174,6 @@ struct StateCheckersTest : Test, DistributorStripeTestUtil { bool _includeMessagePriority {false}; bool _includeSchedulingPriority {false}; bool _merge_operations_disabled {false}; - bool _prioritize_global_bucket_merges {true}; bool _config_enable_default_space_merge_inhibition {false}; bool _merges_inhibited_in_bucket_space {false}; CheckerParams(); @@ -217,10 +215,6 @@ struct StateCheckersTest : Test, DistributorStripeTestUtil { _merge_operations_disabled = disabled; return *this; } - CheckerParams& prioritize_global_bucket_merges(bool enabled) noexcept { - _prioritize_global_bucket_merges = enabled; - return 
*this; - } CheckerParams& bucket_space(document::BucketSpace bucket_space) noexcept { _bucket_space = bucket_space; return *this; @@ -246,7 +240,6 @@ struct StateCheckersTest : Test, DistributorStripeTestUtil { enable_cluster_state(params._clusterState); vespa::config::content::core::StorDistributormanagerConfigBuilder config; config.mergeOperationsDisabled = params._merge_operations_disabled; - config.prioritizeGlobalBucketMerges = params._prioritize_global_bucket_merges; config.inhibitDefaultMergesWhenGlobalMergesPending = params._config_enable_default_space_merge_inhibition; configure_stripe(config); if (!params._pending_cluster_state.empty()) { @@ -734,7 +727,7 @@ TEST_F(StateCheckersTest, synchronize_and_move) { .clusterState("distributor:1 storage:4")); } -TEST_F(StateCheckersTest, global_bucket_merges_have_very_high_priority_if_prioritization_enabled) { +TEST_F(StateCheckersTest, global_bucket_merges_have_very_high_priority) { runAndVerify<SynchronizeAndMoveStateChecker>( CheckerParams().expect( "[Synchronizing buckets with different checksums " @@ -745,23 +738,7 @@ TEST_F(StateCheckersTest, global_bucket_merges_have_very_high_priority_if_priori .bucketInfo("0=1,1=2") .bucket_space(document::FixedBucketSpaces::global_space()) .includeSchedulingPriority(true) - .includeMessagePriority(true) - .prioritize_global_bucket_merges(true)); -} - -TEST_F(StateCheckersTest, global_bucket_merges_have_normal_priority_if_prioritization_disabled) { - runAndVerify<SynchronizeAndMoveStateChecker>( - CheckerParams().expect( - "[Synchronizing buckets with different checksums " - "node(idx=0,crc=0x1,docs=1/1,bytes=1/1,trusted=false,active=false,ready=false), " - "node(idx=1,crc=0x2,docs=2/2,bytes=2/2,trusted=false,active=false,ready=false)] " - "(pri 120) " - "(scheduling pri MEDIUM)") - .bucketInfo("0=1,1=2") - .bucket_space(document::FixedBucketSpaces::global_space()) - .includeSchedulingPriority(true) - .includeMessagePriority(true) - .prioritize_global_bucket_merges(false)); + .includeMessagePriority(true)); } // Upon entering a cluster state transition edge the distributor will diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp index 4911ad88692..4846c90397a 100644 --- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp +++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp @@ -5,6 +5,7 @@ #include <tests/common/teststorageapp.h> #include <tests/persistence/filestorage/forwardingmessagesender.h> #include <vespa/config/common/exceptions.h> +#include <memory> #include <vespa/config/helper/configgetter.hpp> #include <vespa/document/fieldset/fieldsets.h> #include <vespa/document/repo/documenttyperepo.h> @@ -92,16 +93,16 @@ struct FileStorTestBase : Test { std::unique_ptr<vdstestlib::DirConfig> config; std::unique_ptr<vdstestlib::DirConfig> config2; std::unique_ptr<vdstestlib::DirConfig> smallConfig; - const uint32_t _waitTime; + const int32_t _waitTime; const document::DocumentType* _testdoctype1; - FileStorTestBase() : _node(), _waitTime(LONG_WAITTIME) {} + FileStorTestBase(); ~FileStorTestBase() override; void SetUp() override; void TearDown() override; - void createBucket(document::BucketId bid) { + void createBucket(document::BucketId bid) const { _node->getPersistenceProvider().createBucket(makeSpiBucket(bid)); StorBucketDatabase::WrappedEntry entry( @@ -110,13 +111,13 @@ struct FileStorTestBase : Test { entry.write(); } - document::Document::UP createDocument(const 
std::string& content, const std::string& id) { + document::Document::UP createDocument(const std::string& content, const std::string& id) const { return _node->getTestDocMan().createDocument(content, id); } std::shared_ptr<api::PutCommand> make_put_command(StorageMessage::Priority pri = 20, const std::string& docid = "id:foo:testdoctype1::bar", - Timestamp timestamp = 100) { + Timestamp timestamp = 100) const { Document::SP doc(createDocument("my content", docid)); auto bucket = make_bucket_for_doc(doc->getId()); auto cmd = std::make_shared<api::PutCommand>(bucket, std::move(doc), timestamp); @@ -124,7 +125,7 @@ struct FileStorTestBase : Test { return cmd; } - std::shared_ptr<api::GetCommand> make_get_command(StorageMessage::Priority pri, + static std::shared_ptr<api::GetCommand> make_get_command(StorageMessage::Priority pri, const std::string& docid = "id:foo:testdoctype1::bar") { document::DocumentId did(docid); auto bucket = make_bucket_for_doc(did); @@ -139,12 +140,11 @@ struct FileStorTestBase : Test { auto clusterStateBundle = _node->getStateUpdater().getClusterStateBundle(); const auto &clusterState = *clusterStateBundle->getBaselineClusterState(); uint16_t distributor( - _node->getDistribution()->getIdealDistributorNode( - clusterState, bucket)); + _node->getDistribution()->getIdealDistributorNode(clusterState, bucket)); return distributor == distributorIndex; } - document::BucketId getFirstBucketNotOwnedByDistributor(uint16_t distributor) { + document::BucketId getFirstBucketNotOwnedByDistributor(uint16_t distributor) const { for (int i = 0; i < 1000; ++i) { if (!ownsBucket(distributor, document::BucketId(16, i))) { return document::BucketId(16, i); @@ -153,28 +153,25 @@ struct FileStorTestBase : Test { return document::BucketId(0); } - spi::dummy::DummyPersistence& getDummyPersistence() { + spi::dummy::DummyPersistence& getDummyPersistence() const { return dynamic_cast<spi::dummy::DummyPersistence&>(_node->getPersistenceProvider()); } - void setClusterState(const std::string& state) { - _node->getStateUpdater().setClusterState( - lib::ClusterState::CSP( - new lib::ClusterState(state))); + void setClusterState(const std::string& state) const { + _node->getStateUpdater().setClusterState(lib::ClusterState::CSP(new lib::ClusterState(state))); } void setupDisks() { std::string rootOfRoot = "filestormanagertest"; - config.reset(new vdstestlib::DirConfig(getStandardConfig(true, rootOfRoot))); + config = std::make_unique<vdstestlib::DirConfig>(getStandardConfig(true, rootOfRoot)); - config2.reset(new vdstestlib::DirConfig(*config)); + config2 = std::make_unique<vdstestlib::DirConfig>(*config); config2->getConfig("stor-server").set("root_folder", rootOfRoot + "-vdsroot.2"); config2->getConfig("stor-devices").set("root_folder", rootOfRoot + "-vdsroot.2"); config2->getConfig("stor-server").set("node_index", "1"); - smallConfig.reset(new vdstestlib::DirConfig(*config)); - vdstestlib::DirConfig::Config& c( - smallConfig->getConfig("stor-filestor", true)); + smallConfig = std::make_unique<vdstestlib::DirConfig>(*config); + vdstestlib::DirConfig::Config& c(smallConfig->getConfig("stor-filestor", true)); c.set("initial_index_read", "128"); c.set("use_direct_io", "false"); c.set("maximum_gap_to_read_through", "64"); @@ -202,11 +199,16 @@ struct FileStorTestBase : Test { std::shared_ptr<api::StorageMessage> cmd, const Metric& metric); - auto& thread_metrics_of(FileStorManager& manager) { + static auto& thread_metrics_of(FileStorManager& manager) { return manager.get_metrics().threads[0]; } }; 
+FileStorTestBase::FileStorTestBase() + : _node(), + _waitTime(LONG_WAITTIME), + _testdoctype1(nullptr) +{} FileStorTestBase::~FileStorTestBase() = default; std::unique_ptr<DiskThread> @@ -243,7 +245,8 @@ struct FileStorHandlerComponents { FileStorMetrics metrics; std::unique_ptr<FileStorHandler> filestorHandler; - FileStorHandlerComponents(FileStorTestBase& test, uint32_t threadsPerDisk = 1) + explicit FileStorHandlerComponents(FileStorTestBase& test) : FileStorHandlerComponents(test, 1) {} + FileStorHandlerComponents(FileStorTestBase& test, uint32_t threadsPerDisk) : top(), dummyManager(new DummyStorageLink), messageSender(*dummyManager), @@ -269,7 +272,7 @@ struct PersistenceHandlerComponents : public FileStorHandlerComponents { BucketOwnershipNotifier bucketOwnershipNotifier; std::unique_ptr<PersistenceHandler> persistenceHandler; - PersistenceHandlerComponents(FileStorTestBase& test) + explicit PersistenceHandlerComponents(FileStorTestBase& test) : FileStorHandlerComponents(test), executor(test._node->executor()), component(test._node->getComponentRegister(), "test"), @@ -311,7 +314,6 @@ FileStorTestBase::TearDown() } struct FileStorManagerTest : public FileStorTestBase { - void do_test_delete_bucket(bool use_throttled_delete); }; TEST_F(FileStorManagerTest, header_only_put) { @@ -767,8 +769,8 @@ TEST_F(FileStorManagerTest, priority) { document::BucketIdFactory factory; // Create buckets in separate, initial pass to avoid races with puts - for (uint32_t i=0; i<documents.size(); ++i) { - document::BucketId bucket(16, factory.getBucketId(documents[i]->getId()).getRawId()); + for (const auto & document : documents) { + document::BucketId bucket(16, factory.getBucketId(document->getId()).getRawId()); _node->getPersistenceProvider().createBucket(makeSpiBucket(bucket)); } @@ -980,9 +982,9 @@ TEST_F(FileStorManagerTest, split_single_group) { } // Test that the documents are all still there - for (uint32_t i=0; i<documents.size(); ++i) { + for (const auto & document : documents) { document::BucketId bucket(17, state ? 
0x10001 : 0x00001); - auto cmd = std::make_shared<api::GetCommand>(makeDocumentBucket(bucket), documents[i]->getId(), document::AllFields::NAME); + auto cmd = std::make_shared<api::GetCommand>(makeDocumentBucket(bucket), document->getId(), document::AllFields::NAME); cmd->setAddress(_storage3); filestorHandler.schedule(cmd); filestorHandler.flush(true); @@ -1159,8 +1161,8 @@ TEST_F(FileStorManagerTest, join) { // Perform a join, check that other files are gone { auto cmd = std::make_shared<api::JoinBucketsCommand>(makeDocumentBucket(document::BucketId(16, 1))); - cmd->getSourceBuckets().emplace_back(document::BucketId(17, 0x00001)); - cmd->getSourceBuckets().emplace_back(document::BucketId(17, 0x10001)); + cmd->getSourceBuckets().emplace_back(17, 0x00001); + cmd->getSourceBuckets().emplace_back(17, 0x10001); filestorHandler.schedule(cmd); filestorHandler.flush(true); ASSERT_EQ(1, top.getNumReplies()); @@ -1371,12 +1373,11 @@ TEST_F(FileStorManagerTest, remove_location) { } } -void FileStorManagerTest::do_test_delete_bucket(bool use_throttled_delete) { +TEST_F(FileStorManagerTest, delete_bucket) { TestFileStorComponents c(*this); auto config_uri = config::ConfigUri(config->getConfigId()); StorFilestorConfigBuilder my_config(*config_from<StorFilestorConfig>(config_uri)); - my_config.usePerDocumentThrottledDeleteBucket = use_throttled_delete; c.manager->on_configure(my_config); auto& top = c.top; @@ -1421,23 +1422,12 @@ void FileStorManagerTest::do_test_delete_bucket(bool use_throttled_delete) { StorBucketDatabase::WrappedEntry entry(_node->getStorageBucketDatabase().get(bid, "foo")); EXPECT_FALSE(entry.exists()); } - if (use_throttled_delete) { - auto& metrics = thread_metrics_of(*c.manager)->remove_by_gid; - EXPECT_EQ(metrics.failed.getValue(), 0); - EXPECT_EQ(metrics.count.getValue(), 1); - // We can't reliably test the actual latency here without wiring mock clock bumping into - // the async remove by GID execution, but we can at least test that we updated the metric. - EXPECT_EQ(metrics.latency.getCount(), 1); - } -} - -// TODO remove once throttled behavior is the default -TEST_F(FileStorManagerTest, delete_bucket_legacy) { - do_test_delete_bucket(false); -} - -TEST_F(FileStorManagerTest, delete_bucket_throttled) { - do_test_delete_bucket(true); + auto& metrics = thread_metrics_of(*c.manager)->remove_by_gid; + EXPECT_EQ(metrics.failed.getValue(), 0); + EXPECT_EQ(metrics.count.getValue(), 1); + // We can't reliably test the actual latency here without wiring mock clock bumping into + // the async remove by GID execution, but we can at least test that we updated the metric. + EXPECT_EQ(metrics.latency.getCount(), 1); } TEST_F(FileStorManagerTest, delete_bucket_rejects_outdated_bucket_info) { diff --git a/storage/src/tests/storageserver/bouncertest.cpp b/storage/src/tests/storageserver/bouncertest.cpp index 225b3c94120..296ed6d23bc 100644 --- a/storage/src/tests/storageserver/bouncertest.cpp +++ b/storage/src/tests/storageserver/bouncertest.cpp @@ -4,6 +4,7 @@ #include <tests/common/testhelper.h> #include <tests/common/teststorageapp.h> #include <vespa/config/common/exceptions.h> +#include <memory> #include <vespa/config/helper/configgetter.hpp> #include <vespa/document/bucket/fixed_bucket_spaces.h> #include <vespa/document/fieldset/fieldsets.h> @@ -41,18 +42,6 @@ struct BouncerTest : public Test { static constexpr int RejectionDisabledConfigValue = -1; - // Note: newThreshold is intentionally int (rather than Priority) in order - // to be able to test out of bounds values. 
- void configureRejectionThreshold(int newThreshold); - - std::shared_ptr<api::StorageCommand> createDummyFeedMessage( - api::Timestamp timestamp, - Priority priority = 0); - - std::shared_ptr<api::StorageCommand> createDummyFeedMessage( - api::Timestamp timestamp, - document::BucketSpace bucketSpace); - void expectMessageBouncedWithRejection() const; void expect_message_bounced_with_node_down_abort() const; void expect_message_bounced_with_shutdown_abort() const; @@ -70,11 +59,11 @@ BouncerTest::BouncerTest() void BouncerTest::setUpAsNode(const lib::NodeType& type) { vdstestlib::DirConfig config(getStandardConfig(type == lib::NodeType::STORAGE)); if (type == lib::NodeType::STORAGE) { - _node.reset(new TestServiceLayerApp(NodeIndex(2), config.getConfigId())); + _node = std::make_unique<TestServiceLayerApp>(NodeIndex(2), config.getConfigId()); } else { - _node.reset(new TestDistributorApp(NodeIndex(2), config.getConfigId())); + _node = std::make_unique<TestDistributorApp>(NodeIndex(2), config.getConfigId()); } - _upper.reset(new DummyStorageLink()); + _upper = std::make_unique<DummyStorageLink>(); using StorBouncerConfig = vespa::config::content::core::StorBouncerConfig; auto cfg_uri = config::ConfigUri(config.getConfigId()); auto cfg = config::ConfigGetter<StorBouncerConfig>::getConfig(cfg_uri.getConfigId(), cfg_uri.getContext()); @@ -104,8 +93,8 @@ BouncerTest::TearDown() { } std::shared_ptr<api::StorageCommand> -BouncerTest::createDummyFeedMessage(api::Timestamp timestamp, - api::StorageMessage::Priority priority) +createDummyFeedMessage(api::Timestamp timestamp, + api::StorageMessage::Priority priority = 0) { auto cmd = std::make_shared<api::RemoveCommand>( makeDocumentBucket(document::BucketId(0)), @@ -116,14 +105,14 @@ BouncerTest::createDummyFeedMessage(api::Timestamp timestamp, } std::shared_ptr<api::StorageCommand> -BouncerTest::createDummyFeedMessage(api::Timestamp timestamp, - document::BucketSpace bucketSpace) +createDummyFeedMessage(api::Timestamp timestamp, + document::BucketSpace bucketSpace) { auto cmd = std::make_shared<api::RemoveCommand>( document::Bucket(bucketSpace, document::BucketId(0)), document::DocumentId("id:ns:foo::bar"), timestamp); - cmd->setPriority(Priority(0)); + cmd->setPriority(BouncerTest::Priority(0)); return cmd; } @@ -226,58 +215,21 @@ BouncerTest::expectMessageNotBounced() const EXPECT_EQ(size_t(1), _lower->getNumCommands()); } -void -BouncerTest::configureRejectionThreshold(int newThreshold) -{ - using Builder = vespa::config::content::core::StorBouncerConfigBuilder; - Builder config; - config.feedRejectionPriorityThreshold = newThreshold; - _manager->on_configure(config); -} - -TEST_F(BouncerTest, reject_lower_prioritized_feed_messages_when_configured) { - configureRejectionThreshold(Priority(120)); - _upper->sendDown(createDummyFeedMessage(11 * 1000000, Priority(121))); - expectMessageBouncedWithRejection(); -} - -TEST_F(BouncerTest, do_not_reject_higher_prioritized_feed_messages_than_configured) { - configureRejectionThreshold(Priority(120)); - _upper->sendDown(createDummyFeedMessage(11 * 1000000, Priority(119))); - expectMessageNotBounced(); -} - -TEST_F(BouncerTest, priority_rejection_threshold_is_exclusive) { - configureRejectionThreshold(Priority(120)); - _upper->sendDown(createDummyFeedMessage(11 * 1000000, Priority(120))); - expectMessageNotBounced(); -} - -TEST_F(BouncerTest, only_priority_reject_feed_messages_when_configured) { - configureRejectionThreshold(RejectionDisabledConfigValue); - // A message with even the lowest priority should 
not be rejected. - _upper->sendDown(createDummyFeedMessage(11 * 1000000, Priority(255))); - expectMessageNotBounced(); -} - TEST_F(BouncerTest, priority_rejection_is_disabled_by_default_in_config) { _upper->sendDown(createDummyFeedMessage(11 * 1000000, Priority(255))); expectMessageNotBounced(); } -TEST_F(BouncerTest, read_only_operations_are_not_priority_rejected) { - configureRejectionThreshold(Priority(1)); +TEST_F(BouncerTest, read_only_operations_are_not_rejected) { // StatBucket is an external operation, but it's not a mutating operation // and should therefore not be blocked. - auto cmd = std::make_shared<api::StatBucketCommand>( - makeDocumentBucket(document::BucketId(16, 5)), ""); + auto cmd = std::make_shared<api::StatBucketCommand>(makeDocumentBucket(document::BucketId(16, 5)), ""); cmd->setPriority(Priority(2)); _upper->sendDown(cmd); expectMessageNotBounced(); } TEST_F(BouncerTest, internal_operations_are_not_rejected) { - configureRejectionThreshold(Priority(1)); document::BucketId bucket(16, 1234); api::BucketInfo info(0x1, 0x2, 0x3); auto cmd = std::make_shared<api::NotifyBucketChangeCommand>(makeDocumentBucket(bucket), info); @@ -286,12 +238,6 @@ TEST_F(BouncerTest, internal_operations_are_not_rejected) { expectMessageNotBounced(); } -TEST_F(BouncerTest, out_of_bounds_config_values_throw_exception) { - EXPECT_THROW(configureRejectionThreshold(256), config::InvalidConfigException); - EXPECT_THROW(configureRejectionThreshold(-2), config::InvalidConfigException); -} - - namespace { std::shared_ptr<const lib::ClusterStateBundle> diff --git a/storage/src/tests/visiting/visitormanagertest.cpp b/storage/src/tests/visiting/visitormanagertest.cpp index 5fa6d4a77d8..64d2042b61a 100644 --- a/storage/src/tests/visiting/visitormanagertest.cpp +++ b/storage/src/tests/visiting/visitormanagertest.cpp @@ -1032,7 +1032,7 @@ TEST_F(VisitorManagerTest, status_page) { EXPECT_THAT(str, HasSubstr("Running 1 visitors")); // 1 active EXPECT_THAT(str, HasSubstr("waiting visitors 1")); // 1 queued EXPECT_THAT(str, HasSubstr("Visitor thread 0")); - EXPECT_THAT(str, HasSubstr("Disconnected visitor timeout")); // verbose per thread + EXPECT_THAT(str, HasSubstr("Iterators per bucket")); // verbose per thread EXPECT_THAT(str, HasSubstr("Message #1 <b>putdocumentmessage</b>")); // 1 active for (uint32_t session = 0; session < 2 ; ++session){ diff --git a/storage/src/vespa/storage/bucketdb/.gitignore b/storage/src/vespa/storage/bucketdb/.gitignore index 3df72b601a2..333f254ba10 100644 --- a/storage/src/vespa/storage/bucketdb/.gitignore +++ b/storage/src/vespa/storage/bucketdb/.gitignore @@ -6,4 +6,3 @@ .deps .libs Makefile -config-stor-bucketdb.* diff --git a/storage/src/vespa/storage/bucketdb/CMakeLists.txt b/storage/src/vespa/storage/bucketdb/CMakeLists.txt index 0fc32f11583..f9f6220ec1e 100644 --- a/storage/src/vespa/storage/bucketdb/CMakeLists.txt +++ b/storage/src/vespa/storage/bucketdb/CMakeLists.txt @@ -13,5 +13,3 @@ vespa_add_library(storage_bucketdb OBJECT striped_btree_lockable_map.cpp DEPENDS ) -vespa_generate_config(storage_bucketdb stor-bucketdb.def) -install_config_definition(stor-bucketdb.def vespa.config.content.core.stor-bucketdb.def) diff --git a/storage/src/vespa/storage/bucketdb/bucketinfo.h b/storage/src/vespa/storage/bucketdb/bucketinfo.h index 8f9b3d3486a..8b37c50e00e 100644 --- a/storage/src/vespa/storage/bucketdb/bucketinfo.h +++ b/storage/src/vespa/storage/bucketdb/bucketinfo.h @@ -91,7 +91,7 @@ public: /** * Returns the number of nodes this entry has. 
*/ - uint32_t getNodeCount() const noexcept { return static_cast<uint32_t>(_nodes.size()); } + uint16_t getNodeCount() const noexcept { return static_cast<uint16_t>(_nodes.size()); } /** * Returns a list of the nodes this entry has. diff --git a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp index d12a9f72ac1..5337be6d79f 100644 --- a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp +++ b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp @@ -148,12 +148,13 @@ DistributorInfoGatherer::operator()(uint64_t bucketId, const StorBucketDatabase: struct MetricsUpdater { struct Count { uint64_t docs; + uint64_t entries; // docs + tombstones uint64_t bytes; uint64_t buckets; uint64_t active; uint64_t ready; - constexpr Count() noexcept : docs(0), bytes(0), buckets(0), active(0), ready(0) {} + constexpr Count() noexcept : docs(0), entries(0), bytes(0), buckets(0), active(0), ready(0) {} }; Count count; uint32_t lowestUsedBit; @@ -174,8 +175,9 @@ struct MetricsUpdater { if (data.getBucketInfo().isReady()) { ++count.ready; } - count.docs += data.getBucketInfo().getDocumentCount(); - count.bytes += data.getBucketInfo().getTotalDocumentSize(); + count.docs += data.getBucketInfo().getDocumentCount(); + count.entries += data.getBucketInfo().getMetaCount(); + count.bytes += data.getBucketInfo().getTotalDocumentSize(); if (bucket.getUsedBits() < lowestUsedBit) { lowestUsedBit = bucket.getUsedBits(); @@ -188,6 +190,7 @@ struct MetricsUpdater { const auto& s = rhs.count; d.buckets += s.buckets; d.docs += s.docs; + d.entries += s.entries; d.bytes += s.bytes; d.ready += s.ready; d.active += s.active; @@ -234,11 +237,15 @@ BucketManager::report(vespalib::JsonStream & json) const { MetricsUpdater m = getMetrics(space.second->bucketDatabase()); output(json, "vds.datastored.bucket_space.buckets_total", m.count.buckets, document::FixedBucketSpaces::to_string(space.first)); + output(json, "vds.datastored.bucket_space.entries", m.count.entries, + document::FixedBucketSpaces::to_string(space.first)); + output(json, "vds.datastored.bucket_space.docs", m.count.docs, + document::FixedBucketSpaces::to_string(space.first)); total.add(m); } const auto & src = total.count; - output(json, "vds.datastored.alldisks.docs", src.docs); - output(json, "vds.datastored.alldisks.bytes", src.bytes); + output(json, "vds.datastored.alldisks.docs", src.docs); + output(json, "vds.datastored.alldisks.bytes", src.bytes); output(json, "vds.datastored.alldisks.buckets", src.buckets); } @@ -258,6 +265,7 @@ BucketManager::updateMetrics() const auto bm = _metrics->bucket_spaces.find(space.first); assert(bm != _metrics->bucket_spaces.end()); bm->second->buckets_total.set(m.count.buckets); + bm->second->entries.set(m.count.entries); bm->second->docs.set(m.count.docs); bm->second->bytes.set(m.count.bytes); bm->second->active_buckets.set(m.count.active); diff --git a/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.cpp b/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.cpp index ca9e556f83c..d2b019cc50d 100644 --- a/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.cpp +++ b/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.cpp @@ -31,6 +31,7 @@ ContentBucketDbMetrics::~ContentBucketDbMetrics() = default; BucketSpaceMetrics::BucketSpaceMetrics(const vespalib::string& space_name, metrics::MetricSet* owner) : metrics::MetricSet("bucket_space", {{"bucketSpace", space_name}}, "", owner), buckets_total("buckets_total", {}, "Total number buckets present in the 
bucket space (ready + not ready)", this), + entries("entries", {}, "Number of entries (documents + tombstones) stored in the bucket space", this), docs("docs", {}, "Documents stored in the bucket space", this), bytes("bytes", {}, "Bytes stored across all documents in the bucket space", this), active_buckets("active_buckets", {}, "Number of active buckets in the bucket space", this), diff --git a/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.h b/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.h index a73bb676526..cab3a397c54 100644 --- a/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.h +++ b/storage/src/vespa/storage/bucketdb/bucketmanagermetrics.h @@ -34,6 +34,7 @@ struct ContentBucketDbMetrics : metrics::MetricSet { struct BucketSpaceMetrics : metrics::MetricSet { // Superficially very similar to DataStoredMetrics, but metric naming and dimensions differ metrics::LongValueMetric buckets_total; + metrics::LongValueMetric entries; metrics::LongValueMetric docs; metrics::LongValueMetric bytes; metrics::LongValueMetric active_buckets; diff --git a/storage/src/vespa/storage/bucketdb/stor-bucketdb.def b/storage/src/vespa/storage/bucketdb/stor-bucketdb.def deleted file mode 100644 index 16a0473fe4a..00000000000 --- a/storage/src/vespa/storage/bucketdb/stor-bucketdb.def +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -namespace=vespa.config.content.core - -## Number of elements to retrieve in one bucket info chunk -bucketinfobatchsize int default=128 restart - -## Chunk level. Set what level of the path which defines one chunk. -## (See doxygen info in bucketmanager.h for more info) -chunklevel int default=1 restart diff --git a/storage/src/vespa/storage/config/CMakeLists.txt b/storage/src/vespa/storage/config/CMakeLists.txt index 8eaf4f359f4..670569ed598 100644 --- a/storage/src/vespa/storage/config/CMakeLists.txt +++ b/storage/src/vespa/storage/config/CMakeLists.txt @@ -12,8 +12,6 @@ vespa_generate_config(storage_storageconfig stor-server.def) install_config_definition(stor-server.def vespa.config.content.core.stor-server.def) vespa_generate_config(storage_storageconfig stor-status.def) install_config_definition(stor-status.def vespa.config.content.core.stor-status.def) -vespa_generate_config(storage_storageconfig stor-opslogger.def) -install_config_definition(stor-opslogger.def vespa.config.content.core.stor-opslogger.def) vespa_generate_config(storage_storageconfig stor-visitordispatcher.def) install_config_definition(stor-visitordispatcher.def vespa.config.content.core.stor-visitordispatcher.def) vespa_generate_config(storage_storageconfig stor-bouncer.def) diff --git a/storage/src/vespa/storage/config/distributorconfiguration.cpp b/storage/src/vespa/storage/config/distributorconfiguration.cpp index 83be5d71b23..7800eb625e3 100644 --- a/storage/src/vespa/storage/config/distributorconfiguration.cpp +++ b/storage/src/vespa/storage/config/distributorconfiguration.cpp @@ -47,7 +47,6 @@ DistributorConfiguration::DistributorConfiguration(StorageComponent& component) _merge_operations_disabled(false), _use_weak_internal_read_consistency_for_client_gets(false), _enable_metadata_only_fetch_phase_for_inconsistent_updates(false), - _prioritize_global_bucket_merges(true), _implicitly_clear_priority_on_schedule(false), _use_unordered_merge_chaining(false), _inhibit_default_merges_when_global_merges_pending(false), @@ -96,25 +95,6 @@ DistributorConfiguration::containsTimeStatement(const 
std::string& documentSelec return visitor.hasCurrentTime; } -void -DistributorConfiguration::configureMaintenancePriorities( - const vespa::config::content::core::StorDistributormanagerConfig& cfg) -{ - MaintenancePriorities& mp(_maintenancePriorities); - mp.mergeMoveToIdealNode = cfg.priorityMergeMoveToIdealNode; - mp.mergeOutOfSyncCopies = cfg.priorityMergeOutOfSyncCopies; - mp.mergeTooFewCopies = cfg.priorityMergeTooFewCopies; - mp.mergeGlobalBuckets = cfg.priorityMergeGlobalBuckets; - mp.activateNoExistingActive = cfg.priorityActivateNoExistingActive; - mp.activateWithExistingActive = cfg.priorityActivateWithExistingActive; - mp.deleteBucketCopy = cfg.priorityDeleteBucketCopy; - mp.joinBuckets = cfg.priorityJoinBuckets; - mp.splitDistributionBits = cfg.prioritySplitDistributionBits; - mp.splitLargeBucket = cfg.prioritySplitLargeBucket; - mp.splitInconsistentBucket = cfg.prioritySplitInconsistentBucket; - mp.garbageCollection = cfg.priorityGarbageCollection; -} - void DistributorConfiguration::configure(const vespa::config::content::core::StorDistributormanagerConfig& config) { @@ -154,11 +134,6 @@ DistributorConfiguration::configure(const vespa::config::content::core::StorDist _garbageCollectionInterval = vespalib::duration::zero(); } - _blockedStateCheckers.clear(); - for (uint32_t i = 0; i < config.blockedstatecheckers.size(); ++i) { - _blockedStateCheckers.insert(config.blockedstatecheckers[i]); - } - _doInlineSplit = config.inlinebucketsplitting; _enableJoinForSiblingLessBuckets = config.enableJoinForSiblingLessBuckets; _enableInconsistentJoin = config.enableInconsistentJoin; @@ -171,7 +146,6 @@ DistributorConfiguration::configure(const vespa::config::content::core::StorDist _merge_operations_disabled = config.mergeOperationsDisabled; _use_weak_internal_read_consistency_for_client_gets = config.useWeakInternalReadConsistencyForClientGets; _enable_metadata_only_fetch_phase_for_inconsistent_updates = config.enableMetadataOnlyFetchPhaseForInconsistentUpdates; - _prioritize_global_bucket_merges = config.prioritizeGlobalBucketMerges; _max_activation_inhibited_out_of_sync_groups = config.maxActivationInhibitedOutOfSyncGroups; _implicitly_clear_priority_on_schedule = config.implicitlyClearBucketPriorityOnSchedule; _use_unordered_merge_chaining = config.useUnorderedMergeChaining; @@ -179,11 +153,8 @@ DistributorConfiguration::configure(const vespa::config::content::core::StorDist _enable_two_phase_garbage_collection = config.enableTwoPhaseGarbageCollection; _enable_condition_probing = config.enableConditionProbing; _enable_operation_cancellation = config.enableOperationCancellation; - _minimumReplicaCountingMode = config.minimumReplicaCountingMode; - configureMaintenancePriorities(config); - if (config.maxClusterClockSkewSec >= 0) { _maxClusterClockSkew = std::chrono::seconds(config.maxClusterClockSkewSec); } diff --git a/storage/src/vespa/storage/config/distributorconfiguration.h b/storage/src/vespa/storage/config/distributorconfiguration.h index 9d879fa62d5..330567953bd 100644 --- a/storage/src/vespa/storage/config/distributorconfiguration.h +++ b/storage/src/vespa/storage/config/distributorconfiguration.h @@ -12,21 +12,20 @@ namespace storage { namespace distributor { struct LegacyDistributorTest; } class DistributorConfiguration { -public: +public: + DistributorConfiguration(const DistributorConfiguration& other) = delete; + DistributorConfiguration& operator=(const DistributorConfiguration& other) = delete; explicit DistributorConfiguration(StorageComponent& component); 
~DistributorConfiguration(); - struct MaintenancePriorities - { - // Defaults for these are chosen as those used as the current (non- - // configurable) values at the time of implementation. - uint8_t mergeMoveToIdealNode {120}; + struct MaintenancePriorities { + uint8_t mergeMoveToIdealNode {165}; uint8_t mergeOutOfSyncCopies {120}; uint8_t mergeTooFewCopies {120}; uint8_t mergeGlobalBuckets {115}; uint8_t activateNoExistingActive {100}; uint8_t activateWithExistingActive {100}; - uint8_t deleteBucketCopy {100}; + uint8_t deleteBucketCopy {120}; uint8_t joinBuckets {155}; uint8_t splitDistributionBits {200}; uint8_t splitLargeBucket {175}; @@ -58,14 +57,6 @@ public: _lastGarbageCollectionChange = lastChangeTime; } - bool stateCheckerIsActive(vespalib::stringref stateCheckerName) const { - return _blockedStateCheckers.find(stateCheckerName) == _blockedStateCheckers.end(); - } - - void disableStateChecker(vespalib::stringref stateCheckerName) { - _blockedStateCheckers.insert(stateCheckerName); - } - void setDoInlineSplit(bool value) { _doInlineSplit = value; } @@ -249,13 +240,6 @@ public: return _max_consecutively_inhibited_maintenance_ticks; } - void set_prioritize_global_bucket_merges(bool prioritize) noexcept { - _prioritize_global_bucket_merges = prioritize; - } - bool prioritize_global_bucket_merges() const noexcept { - return _prioritize_global_bucket_merges; - } - void set_max_activation_inhibited_out_of_sync_groups(uint32_t max_groups) noexcept { _max_activation_inhibited_out_of_sync_groups = max_groups; } @@ -302,9 +286,6 @@ public: bool containsTimeStatement(const std::string& documentSelection) const; private: - DistributorConfiguration(const DistributorConfiguration& other); - DistributorConfiguration& operator=(const DistributorConfiguration& other); - StorageComponent& _component; uint32_t _byteCountSplitLimit; @@ -326,8 +307,6 @@ private: uint32_t _minPendingMaintenanceOps; uint32_t _maxPendingMaintenanceOps; - vespalib::hash_set<vespalib::string> _blockedStateCheckers; - uint32_t _maxVisitorsPerNodePerClientVisitor; uint32_t _minBucketsPerVisitor; @@ -350,7 +329,6 @@ private: bool _merge_operations_disabled; bool _use_weak_internal_read_consistency_for_client_gets; bool _enable_metadata_only_fetch_phase_for_inconsistent_updates; - bool _prioritize_global_bucket_merges; bool _implicitly_clear_priority_on_schedule; bool _use_unordered_merge_chaining; bool _inhibit_default_merges_when_global_merges_pending; @@ -359,10 +337,6 @@ private: bool _enable_operation_cancellation; DistrConfig::MinimumReplicaCountingMode _minimumReplicaCountingMode; - - friend struct distributor::LegacyDistributorTest; - void configureMaintenancePriorities( - const vespa::config::content::core::StorDistributormanagerConfig&); }; } diff --git a/storage/src/vespa/storage/config/stor-bouncer.def b/storage/src/vespa/storage/config/stor-bouncer.def index 1327a6433c8..aa65c31c186 100644 --- a/storage/src/vespa/storage/config/stor-bouncer.def +++ b/storage/src/vespa/storage/config/stor-bouncer.def @@ -1,30 +1,6 @@ # Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=vespa.config.content.core -## Whether or not the bouncer should stop external load from -## entering node when the cluster state is down. -stop_external_load_when_cluster_down bool default=true - -## Sets what node states the node will allow incoming commands -## in. 
-stop_all_load_when_nodestate_not_in string default="uri" - -## Sets whether to just use (self) reported node state or to use wanted state -## if wanted state is worse than the current reported state. -use_wanted_state_if_possible bool default=true - ## The maximum clock skew allowed in the system. Any messages received ## that have a timestamp longer in the future than this will be failed. max_clock_skew_seconds int default=5 - -## If this config value is != -1, the node will reject any external feed -## operations with a priority lower than that specified here. Note that since -## we map priorities in such a way that 0 is the _highest_ priority and 255 the -## _lowest_ priority, for two operations A and B, if B has a lower priority -## than A it will have a higher priority _integer_ value. -## -## Only mutating external feed operations will be blocked. Read-only operations -## and internal operations are always let through. -## -## Default is -1 (i.e. rejection is disabled and load is allowed through) -feed_rejection_priority_threshold int default=-1 diff --git a/storage/src/vespa/storage/config/stor-communicationmanager.def b/storage/src/vespa/storage/config/stor-communicationmanager.def index a1ce8d4e47b..92ae38ea7c6 100644 --- a/storage/src/vespa/storage/config/stor-communicationmanager.def +++ b/storage/src/vespa/storage/config/stor-communicationmanager.def @@ -41,16 +41,6 @@ mbus.num_network_threads int default=1 restart ## The number of events in the queue of a network (FNET) thread before it is woken up. mbus.events_before_wakeup int default=1 restart -## Enable to use above thread pool for encoding replies -## False will use network(fnet) thread -## Deprecated and void -mbus.dispatch_on_encode bool default=true restart - -## Enable to use above thread pool for decoding replies -## False will use network(fnet) thread -## Deprecated and void -mbus.dispatch_on_decode bool default=true restart - ## The number of network (FNET) threads used by the shared rpc resource. rpc.num_network_threads int default=2 restart diff --git a/storage/src/vespa/storage/config/stor-distributormanager.def b/storage/src/vespa/storage/config/stor-distributormanager.def index f40e572e2e3..3136a54d080 100644 --- a/storage/src/vespa/storage/config/stor-distributormanager.def +++ b/storage/src/vespa/storage/config/stor-distributormanager.def @@ -39,22 +39,6 @@ garbagecollection.interval int default=0 ## If false, dont do splits inline with feeding. inlinebucketsplitting bool default=true -## List of state checkers (ideal state generators) that should be ignored in the cluster. -## One or more of the following (case insensitive): -## -## SynchronizeAndMove -## DeleteExtraCopies -## JoinBuckets -## SplitBucket -## SplitInconsistentBuckets -## SetBucketState -## GarbageCollection -blockedstatecheckers[] string restart - -## Whether or not distributor should issue reverts when operations partially -## fail. -enable_revert bool default=true - ## Maximum nodes involved in a merge operation. Currently, this can not be more ## than 16 nodes due to protocol limitations. However, decreasing the max may ## be useful if 16 node merges ends up too expensive. @@ -64,66 +48,6 @@ maximum_nodes_per_merge int default=16 ## distributor thread to be able to call tick() manually and run single threaded start_distributor_thread bool default=true restart -## The number of ticks calls done before a wait is done. 
This can be -## set higher than 10 for the distributor to improve speed of bucket iterations -## while still keep CPU load low/moderate. -ticks_before_wait int default=10 - -## The sleep time between ticks if there are no more queued tasks. -ticks_wait_time_ms int default=1 - -## Max processing time used by deadlock detector. -max_process_time_ms int default=5000 - -## Allow overriding default priorities of certain maintenance operations. -## This is an advanced feature, do not touch this unless you have a very good -## reason to do so! Configuring these values wrongly may cause starvation of -## important operations, leading to unpredictable behavior and/or data loss. -## -## Merge used to move data to ideal location -priority_merge_move_to_ideal_node int default=165 - -## Merge for copies that have gotten out of sync with each other -priority_merge_out_of_sync_copies int default=120 - -## Merge for restoring redundancy of copies -priority_merge_too_few_copies int default=120 - -## Merge that involves a global bucket. There are generally significantly fewer such -## buckets than default-space buckets, and searches to documents in the default space -## may depend on the presence of (all) global documents. Consequently, this priority -## should be higher (i.e. numerically smaller) than that of regular merges. -priority_merge_global_buckets int default=115 - -## Copy activation when there are no other active copies (likely causing -## lack of search coverage for that bucket) -priority_activate_no_existing_active int default=100 - -## Copy activation when there is already an active copy for the bucket. -priority_activate_with_existing_active int default=100 - -## Deletion of bucket copy. -priority_delete_bucket_copy int default=120 - -## Joining caused by bucket siblings getting sufficiently small to fit into a -## single bucket. -priority_join_buckets int default=155 - -## Splitting caused by system increasing its minimum distribution bit count. -priority_split_distribution_bits int default=200 - -## Splitting due to bucket exceeding max document count or byte size (see -## splitcount and splitsize config values) -priority_split_large_bucket int default=175 - -## Splitting due to buckets being inconsistently split. Should be higher -## priority than the vast majority of external load. -priority_split_inconsistent_bucket int default=110 - -## Background garbage collection. Should be lower priority than external load -## and other ideal state operations (aside from perhaps minimum bit splitting). -priority_garbage_collection int default=200 - ## The distributor can send joins that "lift" a bucket without any siblings ## higher up in the bucket tree hierarchy. The assumption is that if this ## is done for all sibling-less buckets, they will all eventually reach a @@ -206,11 +130,6 @@ allow_stale_reads_during_cluster_state_transitions bool default=false simulated_db_pruning_latency_msec int default=0 simulated_db_merging_latency_msec int default=0 -## Whether to use a B-tree data structure for the distributor bucket database instead -## of the legacy database. Setting this option may trigger alternate code paths for -## read only operations, as the B-tree database is thread safe for concurrent reads. -use_btree_database bool default=true restart - ## If a bucket is inconsistent and an Update operation is received, a two-phase ## write-repair path is triggered in which a Get is sent to all diverging replicas. 
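As an aside to the maintenance-priority defaults removed in the stor-distributormanager.def hunk above: a minimal, self-contained sketch (names hypothetical, not Vespa code) of the priority convention those removed comments describe, where 0 is the highest priority and 255 the lowest, so the numerically smaller value wins.

#include <cstdint>

// Illustrative only: storage priorities are single bytes where a numerically
// smaller value means a more urgent operation, per the removed config comments.
using Priority = uint8_t;

constexpr bool has_higher_priority(Priority lhs, Priority rhs) {
    return lhs < rhs;  // 0 is the highest priority, 255 the lowest
}

// With the removed defaults, inconsistent-bucket splits (110) outrank
// background garbage collection (200).
static_assert(has_higher_priority(110, 200));
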
## Once received, the update is applied on the distributor and pushed out to the @@ -252,15 +171,6 @@ enable_metadata_only_fetch_phase_for_inconsistent_updates bool default=true ## accesses when the distributor is heavily loaded with feed operations. max_consecutively_inhibited_maintenance_ticks int default=20 -## If set, pending merges to buckets in the global bucket space will be prioritized -## higher than merges to buckets in the default bucket space. This ensures that global -## documents will be kept in sync without being starved by non-global documents. -## Note that enabling this feature risks starving default bucket space merges if a -## resource exhaustion case prevents global merges from completing. -## This is a live config for that reason, i.e. it can be disabled in an emergency -## situation if needed. -prioritize_global_bucket_merges bool default=true - ## If set, activation of bucket replicas is limited to only those replicas that have ## bucket info consistent with a majority of the other replicas for that bucket. ## Multiple active replicas is only a feature that is enabled for grouped clusters, diff --git a/storage/src/vespa/storage/config/stor-opslogger.def b/storage/src/vespa/storage/config/stor-opslogger.def deleted file mode 100644 index 40124b9ff03..00000000000 --- a/storage/src/vespa/storage/config/stor-opslogger.def +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -namespace=vespa.config.content.core - -targetfile string default="" restart diff --git a/storage/src/vespa/storage/config/stor-server.def b/storage/src/vespa/storage/config/stor-server.def index 0d877d33277..26b8450ab20 100644 --- a/storage/src/vespa/storage/config/stor-server.def +++ b/storage/src/vespa/storage/config/stor-server.def @@ -21,14 +21,6 @@ is_distributor bool restart ## other nodes. node_capacity double default=1.0 restart -## Capacity of the disks on this node. How much data and load will each disk -## get relative to the other disks on this node. -disk_capacity[] double restart - -## Reliability of this node. How much of the cluster redundancy factor can this -## node make up for. -node_reliability int default=1 restart - ## The upper bound of merges that any storage node can have active. ## A merge operation will be chained through all nodes involved in the ## merge, only actually starting the operation when every node has @@ -56,7 +48,7 @@ merge_throttling_policy.window_size_increment double default=2.0 ## > 0 explicit limit in bytes ## == 0 limit automatically deduced by content node ## < 0 unlimited (legacy behavior) -merge_throttling_memory_limit.max_usage_bytes long default=-1 +merge_throttling_memory_limit.max_usage_bytes long default=0 ## If merge_throttling_memory_limit.max_usage_bytes == 0, this factor is used ## as a multiplier to automatically deduce a memory limit for merges on the @@ -109,19 +101,10 @@ enable_dead_lock_detector_warnings bool default=true ## allow for more slack before dead lock detector kicks in. The value is in seconds. dead_lock_detector_timeout_slack double default=240 -## If set to 0, storage will attempt to auto-detect the number of VDS mount -## points to use. If set to a number, force this number. This number only makes -## sense on a storage node of course -disk_count int default=0 restart - ## Configure persistence provider. Temporary here to test. 
persistence_provider.type enum {STORAGE, DUMMY, RPC } default=STORAGE restart persistence_provider.rpc.connectspec string default="tcp/localhost:27777" restart -## Whether or not to use the new metadata flow implementation. Default to not -## as it is currently in development and not even functional -switch_new_meta_data_flow bool default=false restart - ## When the content layer receives a set of changed buckets from the persistence ## layer, it must recheck all of these. Each such recheck results in an ## operation scheduled against the persistence queust and since the total @@ -135,10 +118,6 @@ bucket_rechecking_chunk_size int default=100 ## Only useful for testing! simulated_bucket_request_latency_msec int default=0 -## If set, content node processes will use a B-tree backed bucket database implementation -## instead of the legacy Judy-based implementation. -use_content_node_btree_bucket_db bool default=true restart - ## If non-zero, the bucket DB will be striped into 2^bits sub-databases, each handling ## a disjoint subset of the node's buckets, in order to reduce locking contention. ## Max value is unspecified, but will be clamped internally. diff --git a/storage/src/vespa/storage/distributor/idealstatemanager.cpp b/storage/src/vespa/storage/distributor/idealstatemanager.cpp index 65e036282d3..0f5d0e48f5a 100644 --- a/storage/src/vespa/storage/distributor/idealstatemanager.cpp +++ b/storage/src/vespa/storage/distributor/idealstatemanager.cpp @@ -71,11 +71,6 @@ IdealStateManager::runStateCheckers(const StateChecker::Context& c) const // We go through _all_ active state checkers so that statistics can be // collected across all checkers, not just the ones that are highest pri. for (const auto & checker : _stateCheckers) { - if (!operation_context().distributor_config().stateCheckerIsActive(checker->getName())) { - LOG(spam, "Skipping state checker %s", checker->getName()); - continue; - } - auto result = checker->check(c); if (canOverwriteResult(highestPri, result)) { highestPri = std::move(result); @@ -146,7 +141,7 @@ IdealStateManager::generateInterceptingSplit(BucketSpace bucketSpace, const Buck c.set_entry(e); IdealStateOperation::UP operation(_splitBucketStateChecker->check(c).createOperation()); - if (operation.get()) { + if (operation) { operation->setPriority(pri); operation->setIdealStateManager(this); } @@ -159,7 +154,7 @@ IdealStateManager::generate(const document::Bucket& bucket) const { NodeMaintenanceStatsTracker statsTracker; IdealStateOperation::SP op(generateHighestPriority(bucket, statsTracker).createOperation()); - if (op.get()) { + if (op) { op->setIdealStateManager(const_cast<IdealStateManager*>(this)); } return op; diff --git a/storage/src/vespa/storage/distributor/statecheckers.cpp b/storage/src/vespa/storage/distributor/statecheckers.cpp index 97641ae86a6..86790a2ddb7 100644 --- a/storage/src/vespa/storage/distributor/statecheckers.cpp +++ b/storage/src/vespa/storage/distributor/statecheckers.cpp @@ -721,7 +721,7 @@ checkForNodesMissingFromIdealState(const StateChecker::Context& c) if (c.idealState().size() > c.entry()->getNodeCount()) { ret.markMissingReplica(node, mp.mergeTooFewCopies); } else { - ret.markMoveToIdealLocation(node,mp.mergeMoveToIdealNode); + ret.markMoveToIdealLocation(node, mp.mergeMoveToIdealNode); } c.stats.incCopyingIn(node, c.getBucketSpace()); hasMissingReplica = true; @@ -807,9 +807,7 @@ SynchronizeAndMoveStateChecker::check(const Context &c) const c.distributorConfig.getMaxNodesPerMerge()); op->setDetailedReason(result.reason()); 
MaintenancePriority::Priority schedPri; - if ((c.getBucketSpace() == document::FixedBucketSpaces::default_space()) - || !c.distributorConfig.prioritize_global_bucket_merges()) - { + if (c.getBucketSpace() == document::FixedBucketSpaces::default_space()) { schedPri = (result.needsMoveOnly() ? MaintenancePriority::LOW : MaintenancePriority::MEDIUM); op->setPriority(result.priority()); } else { diff --git a/storage/src/vespa/storage/persistence/asynchandler.cpp b/storage/src/vespa/storage/persistence/asynchandler.cpp index a69f9e55afb..725cf2c7511 100644 --- a/storage/src/vespa/storage/persistence/asynchandler.cpp +++ b/storage/src/vespa/storage/persistence/asynchandler.cpp @@ -225,29 +225,6 @@ AsyncHandler::on_delete_bucket_complete(const document::Bucket& bucket) const { } } -MessageTracker::UP -AsyncHandler::handleDeleteBucket(api::DeleteBucketCommand& cmd, MessageTracker::UP tracker) const -{ - tracker->setMetric(_env._metrics.deleteBuckets); - LOG(debug, "DeletingBucket(%s)", cmd.getBucketId().toString().c_str()); - if (_env._fileStorHandler.isMerging(cmd.getBucket())) { - _env._fileStorHandler.clearMergeStatus(cmd.getBucket(), - api::ReturnCode(api::ReturnCode::ABORTED, "Bucket was deleted during the merge")); - } - spi::Bucket bucket(cmd.getBucket()); - if (!checkProviderBucketInfoMatches(bucket, cmd.getBucketInfo())) { - return tracker; - } - - auto task = makeResultTask([this, tracker = std::move(tracker), bucket = cmd.getBucket()]([[maybe_unused]] spi::Result::UP ignored) { - // TODO Even if an non OK response can not be handled sanely we might probably log a message, or increment a metric - on_delete_bucket_complete(bucket); - tracker->sendReply(); - }); - _spi.deleteBucketAsync(bucket, std::make_unique<ResultTaskOperationDone>(_sequencedExecutor, cmd.getBucketId(), std::move(task))); - return tracker; -} - namespace { class GatherBucketMetadata : public BucketProcessor::EntryProcessor { diff --git a/storage/src/vespa/storage/persistence/asynchandler.h b/storage/src/vespa/storage/persistence/asynchandler.h index 1433176036b..c78dfe6282d 100644 --- a/storage/src/vespa/storage/persistence/asynchandler.h +++ b/storage/src/vespa/storage/persistence/asynchandler.h @@ -31,7 +31,6 @@ public: MessageTrackerUP handleUpdate(api::UpdateCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleRunTask(RunTaskCommand & cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleSetBucketState(api::SetBucketStateCommand& cmd, MessageTrackerUP tracker) const; - MessageTrackerUP handleDeleteBucket(api::DeleteBucketCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handle_delete_bucket_throttling(api::DeleteBucketCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleCreateBucket(api::CreateBucketCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleRemoveLocation(api::RemoveLocationCommand& cmd, MessageTrackerUP tracker) const; diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp index 61c7da6e286..90703050009 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp @@ -212,8 +212,7 @@ FileStorManager::on_configure(const StorFilestorConfig& config) _use_async_message_handling_on_schedule = config.useAsyncMessageHandlingOnSchedule; _host_info_reporter.set_noise_level(config.resourceUsageReporterNoiseLevel); - const bool use_dynamic_throttling 
= ((config.asyncOperationThrottlerType == StorFilestorConfig::AsyncOperationThrottlerType::DYNAMIC) || - (config.asyncOperationThrottler.type == StorFilestorConfig::AsyncOperationThrottler::Type::DYNAMIC)); + const bool use_dynamic_throttling = (config.asyncOperationThrottler.type == StorFilestorConfig::AsyncOperationThrottler::Type::DYNAMIC); const bool throttle_merge_feed_ops = config.asyncOperationThrottler.throttleIndividualMergeFeedOps; if (!liveUpdate) { @@ -248,7 +247,6 @@ FileStorManager::on_configure(const StorFilestorConfig& config) std::lock_guard guard(_lock); for (auto& ph : _persistenceHandlers) { ph->set_throttle_merge_feed_ops(throttle_merge_feed_ops); - ph->set_use_per_document_throttled_delete_bucket(config.usePerDocumentThrottledDeleteBucket); } } } diff --git a/storage/src/vespa/storage/persistence/persistencehandler.cpp b/storage/src/vespa/storage/persistence/persistencehandler.cpp index 78ad6eac0e2..29d39845f5a 100644 --- a/storage/src/vespa/storage/persistence/persistencehandler.cpp +++ b/storage/src/vespa/storage/persistence/persistencehandler.cpp @@ -24,8 +24,7 @@ PersistenceHandler::PersistenceHandler(vespalib::ISequencedTaskExecutor & sequen cfg.commonMergeChainOptimalizationMinimumSize), _asyncHandler(_env, provider, bucketOwnershipNotifier, sequencedExecutor, component.getBucketIdFactory()), _splitJoinHandler(_env, provider, bucketOwnershipNotifier, cfg.enableMultibitSplitOptimalization), - _simpleHandler(_env, provider, component.getBucketIdFactory()), - _use_per_op_throttled_delete_bucket(false) + _simpleHandler(_env, provider, component.getBucketIdFactory()) { } @@ -69,11 +68,7 @@ PersistenceHandler::handleCommandSplitByType(api::StorageCommand& msg, MessageTr case api::MessageType::CREATEBUCKET_ID: return _asyncHandler.handleCreateBucket(static_cast<api::CreateBucketCommand&>(msg), std::move(tracker)); case api::MessageType::DELETEBUCKET_ID: - if (use_per_op_throttled_delete_bucket()) { - return _asyncHandler.handle_delete_bucket_throttling(static_cast<api::DeleteBucketCommand&>(msg), std::move(tracker)); - } else { - return _asyncHandler.handleDeleteBucket(static_cast<api::DeleteBucketCommand&>(msg), std::move(tracker)); - } + return _asyncHandler.handle_delete_bucket_throttling(static_cast<api::DeleteBucketCommand&>(msg), std::move(tracker)); case api::MessageType::JOINBUCKETS_ID: return _splitJoinHandler.handleJoinBuckets(static_cast<api::JoinBucketsCommand&>(msg), std::move(tracker)); case api::MessageType::SPLITBUCKET_ID: @@ -114,7 +109,7 @@ PersistenceHandler::handleCommandSplitByType(api::StorageCommand& msg, MessageTr default: break; } - return MessageTracker::UP(); + return {}; } MessageTracker::UP @@ -186,14 +181,4 @@ PersistenceHandler::set_throttle_merge_feed_ops(bool throttle) noexcept _mergeHandler.set_throttle_merge_feed_ops(throttle); } -bool -PersistenceHandler::use_per_op_throttled_delete_bucket() const noexcept { - return _use_per_op_throttled_delete_bucket.load(std::memory_order_relaxed); -} - -void -PersistenceHandler::set_use_per_document_throttled_delete_bucket(bool throttle) noexcept { - _use_per_op_throttled_delete_bucket.store(throttle, std::memory_order_relaxed); -} - } diff --git a/storage/src/vespa/storage/persistence/persistencehandler.h b/storage/src/vespa/storage/persistence/persistencehandler.h index 9639b772a28..595815d2bb3 100644 --- a/storage/src/vespa/storage/persistence/persistencehandler.h +++ b/storage/src/vespa/storage/persistence/persistencehandler.h @@ -38,14 +38,12 @@ public: const SimpleMessageHandler & 
simpleMessageHandler() const { return _simpleHandler; } void set_throttle_merge_feed_ops(bool throttle) noexcept; - void set_use_per_document_throttled_delete_bucket(bool throttle) noexcept; private: // Message handling functions MessageTracker::UP handleCommandSplitByType(api::StorageCommand&, MessageTracker::UP tracker) const; MessageTracker::UP handleReply(api::StorageReply&, MessageTracker::UP) const; MessageTracker::UP processMessage(api::StorageMessage& msg, MessageTracker::UP tracker) const; - [[nodiscard]] bool use_per_op_throttled_delete_bucket() const noexcept; const framework::Clock & _clock; PersistenceUtil _env; @@ -54,7 +52,6 @@ private: AsyncHandler _asyncHandler; SplitJoinHandler _splitJoinHandler; SimpleMessageHandler _simpleHandler; - std::atomic<bool> _use_per_op_throttled_delete_bucket; }; } // storage diff --git a/storage/src/vespa/storage/storageserver/bouncer.cpp b/storage/src/vespa/storage/storageserver/bouncer.cpp index bfc38e0c8ba..0aedee6799d 100644 --- a/storage/src/vespa/storage/storageserver/bouncer.cpp +++ b/storage/src/vespa/storage/storageserver/bouncer.cpp @@ -2,16 +2,13 @@ #include "bouncer.h" #include "bouncer_metrics.h" -#include "config_logging.h" #include <vespa/storageframework/generic/clock/clock.h> #include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/vdslib/state/clusterstate.h> #include <vespa/persistence/spi/bucket_limits.h> #include <vespa/storageapi/message/state.h> #include <vespa/storageapi/message/persistence.h> -#include <vespa/config/subscription/configuri.h> #include <vespa/config/helper/configfetcher.hpp> -#include <vespa/config/common/exceptions.h> #include <vespa/vespalib/util/stringfmt.h> #include <sstream> @@ -62,7 +59,6 @@ Bouncer::onClose() void Bouncer::on_configure(const vespa::config::content::core::StorBouncerConfig& config) { - validateConfig(config); auto new_config = std::make_unique<StorBouncerConfig>(config); std::lock_guard lock(_lock); _config = std::move(new_config); @@ -72,27 +68,6 @@ const BouncerMetrics& Bouncer::metrics() const noexcept { return *_metrics; } -void -Bouncer::validateConfig(const vespa::config::content::core::StorBouncerConfig& newConfig) const -{ - if (newConfig.feedRejectionPriorityThreshold != -1) { - if (newConfig.feedRejectionPriorityThreshold - > std::numeric_limits<api::StorageMessage::Priority>::max()) - { - throw config::InvalidConfigException( - "feed_rejection_priority_threshold config value exceeds " - "maximum allowed value"); - } - if (newConfig.feedRejectionPriorityThreshold - < std::numeric_limits<api::StorageMessage::Priority>::min()) - { - throw config::InvalidConfigException( - "feed_rejection_priority_threshold config value lower than " - "minimum allowed value"); - } - } -} - void Bouncer::append_node_identity(std::ostream& target_stream) const { target_stream << " (on " << _component.getNodeType() << '.' << _component.getIndex() << ")"; } @@ -101,8 +76,7 @@ void Bouncer::abortCommandForUnavailableNode(api::StorageMessage& msg, const lib::State& state) { // If we're not up or retired, fail due to this nodes state. 
- std::shared_ptr<api::StorageReply> reply( - static_cast<api::StorageCommand&>(msg).makeReply()); + std::shared_ptr<api::StorageReply> reply(static_cast<api::StorageCommand&>(msg).makeReply()); std::ostringstream ost; ost << "We don't allow command of type " << msg.getType() << " when node is in state " << state.toString(true); @@ -235,18 +209,14 @@ Bouncer::onDown(const std::shared_ptr<api::StorageMessage>& msg) const lib::State* state; int maxClockSkewInSeconds; bool isInAvailableState; - bool abortLoadWhenClusterDown; bool closed; const lib::State* cluster_state; - int feedPriorityLowerBound; { std::lock_guard lock(_lock); state = &getDerivedNodeState(msg->getBucket().getBucketSpace()).getState(); maxClockSkewInSeconds = _config->maxClockSkewSeconds; - abortLoadWhenClusterDown = _config->stopExternalLoadWhenClusterDown; cluster_state = _clusterState; - isInAvailableState = state->oneOf(_config->stopAllLoadWhenNodestateNotIn.c_str()); - feedPriorityLowerBound = _config->feedRejectionPriorityThreshold; + isInAvailableState = state->oneOf("uri"); closed = _closed; } const api::MessageType& type = msg->getType(); @@ -292,13 +262,6 @@ Bouncer::onDown(const std::shared_ptr<api::StorageMessage>& msg) if (!externalLoad) { return false; } - if (priorityRejectionIsEnabled(feedPriorityLowerBound) - && isExternalWriteOperation(type) - && (msg->getPriority() > feedPriorityLowerBound)) - { - rejectDueToInsufficientPriority(*msg, feedPriorityLowerBound); - return true; - } uint64_t timestamp = extractMutationTimestampIfAny(*msg); if (timestamp != 0) { @@ -311,7 +274,7 @@ Bouncer::onDown(const std::shared_ptr<api::StorageMessage>& msg) } // If cluster state is not up, fail external load - if (abortLoadWhenClusterDown && !clusterIsUp(*cluster_state)) { + if (!clusterIsUp(*cluster_state)) { abortCommandDueToClusterDown(*msg, *cluster_state); return true; } diff --git a/storage/src/vespa/storage/storageserver/bouncer.h b/storage/src/vespa/storage/storageserver/bouncer.h index 26282625269..51620a49bda 100644 --- a/storage/src/vespa/storage/storageserver/bouncer.h +++ b/storage/src/vespa/storage/storageserver/bouncer.h @@ -49,7 +49,6 @@ public: const BouncerMetrics& metrics() const noexcept; private: - void validateConfig(const vespa::config::content::core::StorBouncerConfig&) const; void onClose() override; void abortCommandForUnavailableNode(api::StorageMessage&, const lib::State&); void rejectCommandWithTooHighClockSkew(api::StorageMessage& msg, int maxClockSkewInSeconds); @@ -61,9 +60,7 @@ private: bool isDistributor() const; static bool isExternalLoad(const api::MessageType&) noexcept; static bool isExternalWriteOperation(const api::MessageType&) noexcept; - static bool priorityRejectionIsEnabled(int configuredPriority) noexcept { - return (configuredPriority != -1); - } + /** * If msg is a command containing a mutating timestamp (put, remove or diff --git a/storage/src/vespa/storage/visiting/stor-visitor.def b/storage/src/vespa/storage/visiting/stor-visitor.def index a8da6ee5032..752b4ce39df 100644 --- a/storage/src/vespa/storage/visiting/stor-visitor.def +++ b/storage/src/vespa/storage/visiting/stor-visitor.def @@ -5,35 +5,11 @@ namespace=vespa.config.content.core ## Keep in sync with #stor-filestor:num_visitor_threads visitorthreads int default=16 restart -## Default timeout of visitors that loses contact with client (in seconds) -disconnectedvisitortimeout int default=0 restart - -## Time period (in seconds) in which to ignore requests to visitors that doesnt -## exist anymore. 
(Normal for visitors to get some messages right after -## aborting, logging them as faults instead after this timeout has passed.) -ignorenonexistingvisitortimelimit int default=300 restart - ## The number of buckets that are visited in parallel in a visitor visiting ## multiple buckets. Default is 8, meaning if you send a create visitor to visit ## 100 buckets, 8 of them will be visited in parallel. defaultparalleliterators int default=8 -## Default number of maximum client replies pending. -defaultpendingmessages int default=32 - -## Default size of docblocks used to transfer visitor data. -defaultdocblocksize int default=4190208 - -## Default docblock timeout in ms used to transfer visitor data. -## Currently defaults to a day. This is to avoid slow visitor target problems, -## getting data resent faster than it can process, and since there are very few -## reasons to actually time out -defaultdocblocktimeout int default=180000 - -## Default timeout of visitor info messages: Progress and error reports. -## If these time out, the visitor will be aborted on the storage node. -defaultinfotimeout int default=60000 - ## Max concurrent visitors (legacy) maxconcurrentvisitors int default=64 diff --git a/storage/src/vespa/storage/visiting/visitor.cpp b/storage/src/vespa/storage/visiting/visitor.cpp index 6904ecd1450..ceb356a982c 100644 --- a/storage/src/vespa/storage/visiting/visitor.cpp +++ b/storage/src/vespa/storage/visiting/visitor.cpp @@ -192,6 +192,8 @@ Visitor::Visitor(StorageComponent& component) _hitCounter(), _trace(DEFAULT_TRACE_MEMORY_LIMIT), _messageHandler(nullptr), + _messageSession(), + _documentPriority(documentapi::Priority::PRI_NORMAL_3), _id(), _controlDestination(), _dataDestination(), @@ -207,7 +209,7 @@ Visitor::~Visitor() void Visitor::sendMessage(documentapi::DocumentMessage::UP cmd) { - assert(cmd.get()); + assert(cmd); if (!isRunning()) return; cmd->setRoute(*_dataDestination); @@ -261,10 +263,10 @@ Visitor::sendDocumentApiMessage(VisitorTarget::MessageMeta& msgMeta) { void Visitor::sendInfoMessage(documentapi::VisitorInfoMessage::UP cmd) { - assert(cmd.get()); + assert(cmd); if (!isRunning()) return; - if (_controlDestination->toString().length()) { + if (!_controlDestination->toString().empty()) { cmd->setRoute(*_controlDestination); cmd->setPriority(_documentPriority); cmd->setTimeRemaining(_visitorInfoTimeout); @@ -334,7 +336,7 @@ Visitor::forceClose() void Visitor::sendReplyOnce() { - assert(_initiatingCmd.get()); + assert(_initiatingCmd); if (!_hasSentReply) { std::shared_ptr<api::StorageReply> reply(_initiatingCmd->makeReply()); @@ -521,7 +523,7 @@ Visitor::attach(std::shared_ptr<api::CreateVisitorCommand> initiatingCmd, { _priority = initiatingCmd->getPriority(); _timeToDie = capped_future(_component.getClock().getMonotonicTime(), timeout); - if (_initiatingCmd.get()) { + if (_initiatingCmd) { std::shared_ptr<api::StorageReply> reply(_initiatingCmd->makeReply()); reply->setResult(api::ReturnCode::ABORTED); _messageHandler->send(reply); @@ -923,7 +925,7 @@ Visitor::getStatus(std::ostream& out, bool verbose) const } out << "</td></tr>\n"; out << "<tr><td>Document selection</td><td>"; - if (_documentSelection.get()) { + if (_documentSelection) { out << xml_content_escaped(_documentSelection->toString()); } else { out << "nil"; @@ -990,7 +992,7 @@ Visitor::getStatus(std::ostream& out, bool verbose) const != _visitorTarget._pendingMessages.end()) { out << "<i>pending</i>"; - }; + } auto queued = idToSendTime.find(idAndMeta.first); if (queued != idToSendTime.end()) 
{ out << "Scheduled for sending at timestamp " diff --git a/storage/src/vespa/storage/visiting/visitormanager.cpp b/storage/src/vespa/storage/visiting/visitormanager.cpp index dc1635bc4b1..55948fb47cb 100644 --- a/storage/src/vespa/storage/visiting/visitormanager.cpp +++ b/storage/src/vespa/storage/visiting/visitormanager.cpp @@ -8,7 +8,6 @@ #include "recoveryvisitor.h" #include "reindexing_visitor.h" #include <vespa/storageframework/generic/thread/thread.h> -#include <vespa/config/subscription/configuri.h> #include <vespa/config/common/exceptions.h> #include <vespa/config/helper/configfetcher.hpp> #include <vespa/storage/common/statusmessages.h> @@ -117,11 +116,6 @@ void VisitorManager::on_configure(const vespa::config::content::core::StorVisitorConfig& config) { std::lock_guard sync(_visitorLock); - if (config.defaultdocblocksize % 512 != 0) { - throw config::InvalidConfigException( - "The default docblock size needs to be a multiple of the " - "disk block size. (512b)"); - } // Do some sanity checking of input. Cannot haphazardly mix and match // old and new max concurrency config values diff --git a/storage/src/vespa/storage/visiting/visitorthread.cpp b/storage/src/vespa/storage/visiting/visitorthread.cpp index 0de954c47ed..57198f6761a 100644 --- a/storage/src/vespa/storage/visiting/visitorthread.cpp +++ b/storage/src/vespa/storage/visiting/visitorthread.cpp @@ -21,6 +21,9 @@ using storage::api::ReturnCode; namespace storage { +constexpr uint32_t DEFAULT_PENDING_MESSAGES = 32; +constexpr uint32_t DEFAULT_DOCBLOCK_SIZE = 4_Mi; + VisitorThread::Event::Event(Event&& other) noexcept : _visitorId(other._visitorId), _message(std::move(other._message)), @@ -82,15 +85,9 @@ VisitorThread::VisitorThread(uint32_t threadIndex, _messageSender(sender), _metrics(metrics), _threadIndex(threadIndex), - _disconnectedVisitorTimeout(0), // Need config to set values - _ignoreNonExistingVisitorTimeLimit(0), _defaultParallelIterators(0), _iteratorsPerBucket(1), - _defaultPendingMessages(0), - _defaultDocBlockSize(0), _visitorMemoryUsageLimit(UINT32_MAX), - _defaultDocBlockTimeout(180s), - _defaultVisitorInfoTimeout(60s), _timeBetweenTicks(1000), _component(componentRegister, getThreadName(threadIndex)), _messageSessionFactory(messageSessionFac), @@ -402,11 +399,9 @@ validateDocumentSelection(const document::DocumentTypeRepo& repo, } bool -VisitorThread::onCreateVisitor( - const std::shared_ptr<api::CreateVisitorCommand>& cmd) +VisitorThread::onCreateVisitor(const std::shared_ptr<api::CreateVisitorCommand>& cmd) { metrics::MetricTimer visitorTimer; - assert(_defaultDocBlockSize); // Ensure we've gotten a config assert(_currentlyRunningVisitor == _visitors.end()); ReturnCode result(ReturnCode::OK); std::unique_ptr<document::select::Node> docSelection; @@ -437,7 +432,7 @@ VisitorThread::onCreateVisitor( if (cmd->getMaximumPendingReplyCount() != 0) { visitor->setMaxPending(cmd->getMaximumPendingReplyCount()); } else { - visitor->setMaxPending(_defaultPendingMessages); + visitor->setMaxPending(DEFAULT_PENDING_MESSAGES); } visitor->setFieldSet(cmd->getFieldSet()); @@ -449,11 +444,11 @@ VisitorThread::onCreateVisitor( visitor->setMaxParallel(_defaultParallelIterators); visitor->setMaxParallelPerBucket(_iteratorsPerBucket); - visitor->setDocBlockSize(_defaultDocBlockSize); + visitor->setDocBlockSize(DEFAULT_DOCBLOCK_SIZE); visitor->setMemoryUsageLimit(_visitorMemoryUsageLimit); - visitor->setDocBlockTimeout(_defaultDocBlockTimeout); - visitor->setVisitorInfoTimeout(_defaultVisitorInfoTimeout); + 
visitor->setDocBlockTimeout(180s); + visitor->setVisitorInfoTimeout(60s); visitor->setOwnNodeIndex(_component.getIndex()); visitor->setBucketSpace(cmd->getBucketSpace()); @@ -546,70 +541,26 @@ VisitorThread::onInternal(const std::shared_ptr<api::InternalCommand>& cmd) { auto& pcmd = dynamic_cast<PropagateVisitorConfig&>(*cmd); const vespa::config::content::core::StorVisitorConfig& config(pcmd.getConfig()); - if (_defaultDocBlockSize != 0) { // Live update - LOG(config, "Updating visitor thread configuration in visitor " - "thread %u: " - "Current config(disconnectedVisitorTimeout %u," - " ignoreNonExistingVisitorTimeLimit %u," - " defaultParallelIterators %u," - " iteratorsPerBucket %u," - " defaultPendingMessages %u," - " defaultDocBlockSize %u," - " visitorMemoryUsageLimit %u," - " defaultDocBlockTimeout %" PRIu64 "," - " defaultVisitorInfoTimeout %" PRIu64 ") " - "New config(disconnectedVisitorTimeout %u," - " ignoreNonExistingVisitorTimeLimit %u," - " defaultParallelIterators %u," - " defaultPendingMessages %u," - " defaultDocBlockSize %u," - " visitorMemoryUsageLimit %u," - " defaultDocBlockTimeout %u," - " defaultVisitorInfoTimeout %u) ", - _threadIndex, - _disconnectedVisitorTimeout, - _ignoreNonExistingVisitorTimeLimit, - _defaultParallelIterators, - _iteratorsPerBucket, - _defaultPendingMessages, - _defaultDocBlockSize, - _visitorMemoryUsageLimit, - vespalib::count_ms(_defaultDocBlockTimeout), - vespalib::count_ms(_defaultVisitorInfoTimeout), - config.disconnectedvisitortimeout, - config.ignorenonexistingvisitortimelimit, - config.defaultparalleliterators, - config.defaultpendingmessages, - config.defaultdocblocksize, - config.visitorMemoryUsageLimit, - config.defaultdocblocktimeout, - config.defaultinfotimeout - ); - } - _disconnectedVisitorTimeout = config.disconnectedvisitortimeout; - _ignoreNonExistingVisitorTimeLimit = config.ignorenonexistingvisitortimelimit; + LOG(config, "Updating visitor thread configuration in visitor " + "thread %u: " + "Current config(defaultParallelIterators %u," + " iteratorsPerBucket %u," + " visitorMemoryUsageLimit %u)" + "New config(defaultParallelIterators %u," + " visitorMemoryUsageLimit %u)", + _threadIndex, + _defaultParallelIterators, + _iteratorsPerBucket, + _visitorMemoryUsageLimit, + config.defaultparalleliterators, + config.visitorMemoryUsageLimit + ); _defaultParallelIterators = config.defaultparalleliterators; - _defaultPendingMessages = config.defaultpendingmessages; - _defaultDocBlockSize = config.defaultdocblocksize; _visitorMemoryUsageLimit = config.visitorMemoryUsageLimit; - _defaultDocBlockTimeout = std::chrono::milliseconds(config.defaultdocblocktimeout); - _defaultVisitorInfoTimeout = std::chrono::milliseconds(config.defaultinfotimeout); if (_defaultParallelIterators < 1) { LOG(config, "Cannot use value of defaultParallelIterators < 1"); _defaultParallelIterators = 1; } - if (_defaultPendingMessages < 1) { - LOG(config, "Cannot use value of defaultPendingMessages < 1"); - _defaultPendingMessages = 1; - } - if (_defaultDocBlockSize < 1024) { - LOG(config, "Refusing to use default block size less than 1k"); - _defaultDocBlockSize = 1024; - } - if (_defaultDocBlockTimeout < 1ms) { - LOG(config, "Cannot use value of defaultDocBlockTimeout < 1"); - _defaultDocBlockTimeout = 1ms; - } break; } case RequestStatusPage::ID: @@ -695,20 +646,10 @@ VisitorThread::getStatus(vespalib::asciistream& out, out << "<h3>Current queue size: " << _queue.size() << "</h3>\n"; out << "<h3>Config:</h3>\n" << "<table 
border=\"1\"><tr><td>Parameter</td><td>Value</td></tr>\n" - << "<tr><td>Disconnected visitor timeout</td><td>" - << _disconnectedVisitorTimeout << "</td></tr>\n" - << "<tr><td>Ignore non-existing visitor timelimit</td><td>" - << _ignoreNonExistingVisitorTimeLimit << "</td></tr>\n" << "<tr><td>Default parallel iterators</td><td>" << _defaultParallelIterators << "</td></tr>\n" << "<tr><td>Iterators per bucket</td><td>" << _iteratorsPerBucket << "</td></tr>\n" - << "<tr><td>Default pending messages</td><td>" - << _defaultPendingMessages << "</td></tr>\n" - << "<tr><td>Default DocBlock size</td><td>" - << _defaultDocBlockSize << "</td></tr>\n" - << "<tr><td>Default DocBlock timeout (ms)</td><td>" - << vespalib::count_ms(_defaultDocBlockTimeout) << "</td></tr>\n" << "<tr><td>Visitor memory usage limit</td><td>" << _visitorMemoryUsageLimit << "</td></tr>\n" << "</table>\n"; diff --git a/storage/src/vespa/storage/visiting/visitorthread.h b/storage/src/vespa/storage/visiting/visitorthread.h index 4463a62fdd9..034c726e947 100644 --- a/storage/src/vespa/storage/visiting/visitorthread.h +++ b/storage/src/vespa/storage/visiting/visitorthread.h @@ -75,15 +75,9 @@ class VisitorThread : public framework::Runnable, VisitorMessageHandler& _messageSender; VisitorThreadMetrics& _metrics; uint32_t _threadIndex; - uint32_t _disconnectedVisitorTimeout; - uint32_t _ignoreNonExistingVisitorTimeLimit; uint32_t _defaultParallelIterators; uint32_t _iteratorsPerBucket; - uint32_t _defaultPendingMessages; - uint32_t _defaultDocBlockSize; uint32_t _visitorMemoryUsageLimit; - vespalib::duration _defaultDocBlockTimeout; - vespalib::duration _defaultVisitorInfoTimeout; std::atomic<uint32_t> _timeBetweenTicks; StorageComponent _component; std::unique_ptr<framework::Thread> _thread; diff --git a/storageserver/src/tests/testhelper.cpp b/storageserver/src/tests/testhelper.cpp index 6877ed9aba6..40e263d4e68 100644 --- a/storageserver/src/tests/testhelper.cpp +++ b/storageserver/src/tests/testhelper.cpp @@ -1,6 +1,5 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <tests/testhelper.h> -#include <vespa/vespalib/io/fileutil.h> #include <vespa/log/log.h> LOG_SETUP(".testhelper"); @@ -41,10 +40,7 @@ vdstestlib::DirConfig getStandardConfig(bool storagenode) { config = &dc.addConfig("stor-communicationmanager"); config->set("rpcport", "0"); config->set("mbusport", "0"); - config = &dc.addConfig("stor-bucketdb"); - config->set("chunklevel", "0"); config = &dc.addConfig("stor-distributormanager"); - config = &dc.addConfig("stor-opslogger"); config = &dc.addConfig("stor-filestor"); // Easier to see what goes wrong with only 1 thread per disk. 
config->set("threads[1]"); @@ -90,8 +86,7 @@ vdstestlib::DirConfig getStandardConfig(bool storagenode) { return dc; } -void addSlobrokConfig(vdstestlib::DirConfig& dc, - const mbus::Slobrok& slobrok) +void addSlobrokConfig(vdstestlib::DirConfig& dc, const mbus::Slobrok& slobrok) { std::ostringstream ost; ost << "tcp/localhost:" << slobrok.port(); diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 93e35e4c6d2..c9518b29884 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -84,7 +84,7 @@ RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_feature EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - node->add(0, field_id, 0, 1); + node->add(field_id, 0, 1, 0); auto& field_info = node->getFieldInfo(field_id); field_info.setHitCount(mock_num_occs); field_info.setFieldLength(mock_field_length); diff --git a/streamingvisitors/src/tests/searcher/searcher_test.cpp b/streamingvisitors/src/tests/searcher/searcher_test.cpp index 24877866c1b..daa26b855e8 100644 --- a/streamingvisitors/src/tests/searcher/searcher_test.cpp +++ b/streamingvisitors/src/tests/searcher/searcher_test.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/document/fieldvalue/fieldvalues.h> +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <vespa/searchlib/query/streaming/queryterm.h> #include <vespa/vsm/searcher/boolfieldsearcher.h> @@ -18,10 +19,15 @@ #include <vespa/vsm/searcher/utf8suffixstringfieldsearcher.h> #include <vespa/vsm/searcher/tokenizereader.h> #include <vespa/vsm/vsm/snippetmodifier.h> +#include <concepts> +#include <charconv> +#include <stdexcept> +#include <utility> using namespace document; using search::streaming::HitList; using search::streaming::QueryNodeResultFactory; +using search::streaming::FuzzyTerm; using search::streaming::RegexpTerm; using search::streaming::QueryTerm; using search::streaming::Normalizing; @@ -38,7 +44,7 @@ public: Vector<T> & add(T v) { this->push_back(v); return *this; } }; -using Hits = Vector<size_t>; +using Hits = Vector<std::pair<uint32_t, uint32_t>>; using StringList = Vector<std::string> ; using HitsList = Vector<Hits>; using BoolList = Vector<bool>; @@ -58,6 +64,46 @@ public: } }; +namespace { + +template <std::integral T> +std::string_view maybe_consume_into(std::string_view str, T& val_out) { + auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), val_out); + if (ec != std::errc()) { + return str; + } + return str.substr(ptr - str.data()); +} + +// Parse optional max edits and prefix lock length from term string. +// Syntax: +// "term" -> {2, 0, "term"} (default max edits & prefix length) +// "{1}term" -> {1, 0, "term"} +// "{1,3}term" -> {1, 3, "term"} +// +// Note: this is not a "proper" parser (it accepts empty numeric values); only for testing! 
+std::tuple<uint8_t, uint32_t, std::string_view> parse_fuzzy_params(std::string_view term) { + if (term.empty() || term[0] != '{') { + return {2, 0, term}; + } + uint8_t max_edits = 2; + uint32_t prefix_length = 0; + term = maybe_consume_into(term.substr(1), max_edits); + if (term.empty() || (term[0] != ',' && term[0] != '}')) { + throw std::invalid_argument("malformed fuzzy params at (or after) max_edits"); + } + if (term[0] == '}') { + return {max_edits, prefix_length, term.substr(1)}; + } + term = maybe_consume_into(term.substr(1), prefix_length); + if (term.empty() || term[0] != '}') { + throw std::invalid_argument("malformed fuzzy params at (or after) prefix_length"); + } + return {max_edits, prefix_length, term.substr(1)}; +} + +} + class Query { private: @@ -66,10 +112,14 @@ private: ParsedQueryTerm pqt = parseQueryTerm(term); ParsedTerm pt = parseTerm(pqt.second); std::string effective_index = pqt.first.empty() ? "index" : pqt.first; - if (pt.second != TermType::REGEXP) { - qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + if (pt.second == TermType::REGEXP) { + qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, TermType::REGEXP, normalizing)); + } else if (pt.second == TermType::FUZZYTERM) { + auto [max_edits, prefix_length, actual_term] = parse_fuzzy_params(pt.first); + qtv.push_back(std::make_unique<FuzzyTerm>(eqnr.create(), vespalib::stringref(actual_term.data(), actual_term.size()), + effective_index, TermType::FUZZYTERM, normalizing, max_edits, prefix_length)); } else { - qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); } } for (const auto & i : qtv) { @@ -100,6 +150,8 @@ public: return std::make_pair(term.substr(1, term.size() - 1), TermType::SUFFIXTERM); } else if (term[0] == '#') { // magic regex enabler return std::make_pair(term.substr(1), TermType::REGEXP); + } else if (term[0] == '%') { // equally magic fuzzy enabler + return std::make_pair(term.substr(1), TermType::FUZZYTERM); } else if (term[term.size() - 1] == '*') { return std::make_pair(term.substr(0, term.size() - 1), TermType::PREFIXTERM); } else { @@ -314,7 +366,7 @@ assertNumeric(FieldSearcher & fs, const StringList & query, const FieldValue & f { HitsList hl; for (bool v : exp) { - hl.push_back(v ? Hits().add(0) : Hits()); + hl.push_back(v ? 
Hits().add({0, 0}) : Hits()); } assertSearch(fs, query, fv, hl); } @@ -349,7 +401,9 @@ assertSearch(FieldSearcher & fs, const StringList & query, const FieldValue & fv EXPECT_EQUAL(hl.size(), exp[i].size()); ASSERT_TRUE(hl.size() == exp[i].size()); for (size_t j = 0; j < hl.size(); ++j) { - EXPECT_EQUAL((size_t)hl[j].pos(), exp[i][j]); + EXPECT_EQUAL(0u, hl[j].field_id()); + EXPECT_EQUAL((size_t)hl[j].element_id(), exp[i][j].first); + EXPECT_EQUAL((size_t)hl[j].position(), exp[i][j].second); } } } @@ -414,9 +468,9 @@ bool assertCountWords(size_t numWords, const std::string & field) bool testStringFieldInfo(StrChrFieldSearcher & fs) { - assertString(fs, "foo", StringList().add("foo bar baz").add("foo bar").add("baz foo"), Hits().add(0).add(3).add(6)); + assertString(fs, "foo", StringList().add("foo bar baz").add("foo bar").add("baz foo"), Hits().add({0, 0}).add({1, 0}).add({2, 1})); assertString(fs, StringList().add("foo").add("bar"), StringList().add("foo bar baz").add("foo bar").add("baz foo"), - HitsList().add(Hits().add(0).add(3).add(6)).add(Hits().add(1).add(4))); + HitsList().add(Hits().add({0, 0}).add({1, 0}).add({2, 1})).add(Hits().add({0, 1}).add({1, 1}))); bool retval = true; if (!EXPECT_TRUE(assertFieldInfo(fs, "foo", "foo", QTFieldInfo(0, 1, 1)))) retval = false; @@ -445,22 +499,22 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) std::string field = "operators and operator overloading with utf8 char oe = \xc3\x98"; assertString(fs, "oper", field, Hits()); assertString(fs, "tor", field, Hits()); - assertString(fs, "oper*", field, Hits().add(0).add(2)); - assertString(fs, "and", field, Hits().add(1)); + assertString(fs, "oper*", field, Hits().add({0, 0}).add({0, 2})); + assertString(fs, "and", field, Hits().add({0, 1})); assertString(fs, StringList().add("oper").add("tor"), field, HitsList().add(Hits()).add(Hits())); - assertString(fs, StringList().add("and").add("overloading"), field, HitsList().add(Hits().add(1)).add(Hits().add(3))); + assertString(fs, StringList().add("and").add("overloading"), field, HitsList().add(Hits().add({0, 1})).add(Hits().add({0, 3}))); fs.match_type(FieldSearcher::PREFIX); - assertString(fs, "oper", field, Hits().add(0).add(2)); - assertString(fs, StringList().add("oper").add("tor"), field, HitsList().add(Hits().add(0).add(2)).add(Hits())); + assertString(fs, "oper", field, Hits().add({0, 0}).add({0, 2})); + assertString(fs, StringList().add("oper").add("tor"), field, HitsList().add(Hits().add({0, 0}).add({0, 2})).add(Hits())); fs.match_type(FieldSearcher::REGULAR); if (!EXPECT_TRUE(testStringFieldInfo(fs))) return false; { // test handling of several underscores StringList query = StringList().add("foo").add("bar"); - HitsList exp = HitsList().add(Hits().add(0)).add(Hits().add(1)); + HitsList exp = HitsList().add(Hits().add({0, 0})).add(Hits().add({0, 1})); assertString(fs, query, "foo_bar", exp); assertString(fs, query, "foo__bar", exp); assertString(fs, query, "foo___bar", exp); @@ -470,38 +524,61 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) query = StringList().add("foo").add("thisisaveryveryverylongword"); assertString(fs, query, "foo____________________thisisaveryveryverylongword", exp); - assertString(fs, "bar", "foo bar", Hits().add(1)); - assertString(fs, "bar", "foo____________________bar", Hits().add(1)); - assertString(fs, "bar", "foo____________________thisisaveryveryverylongword____________________bar", Hits().add(2)); + assertString(fs, "bar", "foo bar", Hits().add({0, 1})); + assertString(fs, "bar", 
"foo____________________bar", Hits().add({0, 1})); + assertString(fs, "bar", "foo____________________thisisaveryveryverylongword____________________bar", Hits().add({0, 2})); } return true; } - TEST("verify correct term parsing") { - ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); - ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); - ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); - ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); - ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); - ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); - ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); - ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); - ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); - ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); - ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); - ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); - ASSERT_TRUE(Query::parseTerm("term").first == "term"); - ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); - } - - TEST("suffix matching") { - EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); - } +TEST("parsing of test-only fuzzy term params can extract numeric values") { + uint8_t max_edits = 0; + uint32_t prefix_length = 1234; + std::string_view out; + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{3}myterm"); + EXPECT_EQUAL(max_edits, 3u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{2,70}myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 70u); + EXPECT_EQUAL(out, "myterm"); +} + +TEST("verify correct term parsing") { + ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); + ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); + ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); + ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); + ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); + ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); + ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); + ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); + ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); + ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); + ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); + ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); + ASSERT_TRUE(Query::parseTerm("%fuzzy").first == "fuzzy"); + ASSERT_TRUE(Query::parseTerm("%fuzzy").second == TermType::FUZZYTERM); + ASSERT_TRUE(Query::parseTerm("term").first == "term"); + ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); +} + +TEST("suffix matching") { + EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); + 
EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); +} TEST("Test basic strchrfield searchers") { { @@ -519,16 +596,16 @@ testUTF8SubStringFieldSearcher(StrChrFieldSearcher & fs) { std::string field = "operators and operator overloading"; assertString(fs, "rsand", field, Hits()); - assertString(fs, "ove", field, Hits().add(3)); - assertString(fs, "ing", field, Hits().add(3)); - assertString(fs, "era", field, Hits().add(0).add(2)); - assertString(fs, "a", field, Hits().add(0).add(1).add(2).add(3)); + assertString(fs, "ove", field, Hits().add({0, 3})); + assertString(fs, "ing", field, Hits().add({0, 3})); + assertString(fs, "era", field, Hits().add({0, 0}).add({0, 2})); + assertString(fs, "a", field, Hits().add({0, 0}).add({0, 1}).add({0, 2}).add({0, 3})); assertString(fs, StringList().add("dn").add("gn"), field, HitsList().add(Hits()).add(Hits())); - assertString(fs, StringList().add("ato").add("load"), field, HitsList().add(Hits().add(0).add(2)).add(Hits().add(3))); + assertString(fs, StringList().add("ato").add("load"), field, HitsList().add(Hits().add({0, 0}).add({0, 2})).add(Hits().add({0, 3}))); assertString(fs, StringList().add("aa").add("ab"), "aaaab", - HitsList().add(Hits().add(0).add(0).add(0)).add(Hits().add(0))); + HitsList().add(Hits().add({0, 0}).add({0, 0}).add({0, 0})).add(Hits().add({0, 0}))); if (!EXPECT_TRUE(testStringFieldInfo(fs))) return false; return true; @@ -538,20 +615,20 @@ TEST("utf8 substring search") { { UTF8SubStringFieldSearcher fs(0); EXPECT_TRUE(testUTF8SubStringFieldSearcher(fs)); - assertString(fs, "aa", "aaaa", Hits().add(0).add(0)); + assertString(fs, "aa", "aaaa", Hits().add({0, 0}).add({0, 0})); } { UTF8SubStringFieldSearcher fs(0); EXPECT_TRUE(testUTF8SubStringFieldSearcher(fs)); - assertString(fs, "abc", "abc bcd abc", Hits().add(0).add(2)); + assertString(fs, "abc", "abc bcd abc", Hits().add({0, 0}).add({0, 2})); fs.maxFieldLength(4); - assertString(fs, "abc", "abc bcd abc", Hits().add(0)); + assertString(fs, "abc", "abc bcd abc", Hits().add({0, 0})); } { UTF8SubstringSnippetModifier fs(0); EXPECT_TRUE(testUTF8SubStringFieldSearcher(fs)); // we don't have 1 term optimization - assertString(fs, "aa", "aaaa", Hits().add(0).add(0).add(0)); + assertString(fs, "aa", "aaaa", Hits().add({0, 0}).add({0, 0}).add({0, 0})); } } @@ -567,11 +644,11 @@ TEST("utf8 suffix search") { UTF8SuffixStringFieldSearcher fs(0); std::string field = "operators and operator overloading"; TEST_DO(assertString(fs, "rsand", field, Hits())); - TEST_DO(assertString(fs, "tor", field, Hits().add(2))); - TEST_DO(assertString(fs, "tors", field, Hits().add(0))); + TEST_DO(assertString(fs, "tor", field, Hits().add({0, 2}))); + TEST_DO(assertString(fs, "tors", field, Hits().add({0, 0}))); TEST_DO(assertString(fs, StringList().add("an").add("din"), field, HitsList().add(Hits()).add(Hits()))); - TEST_DO(assertString(fs, StringList().add("nd").add("g"), field, HitsList().add(Hits().add(1)).add(Hits().add(3)))); + TEST_DO(assertString(fs, StringList().add("nd").add("g"), field, HitsList().add(Hits().add({0, 1})).add(Hits().add({0, 3})))); EXPECT_TRUE(testStringFieldInfo(fs)); } @@ -579,14 +656,14 @@ TEST("utf8 suffix search") { TEST("utf8 exact match") { UTF8ExactStringFieldSearcher fs(0); // regular - TEST_DO(assertString(fs, "vespa", "vespa", Hits().add(0))); + TEST_DO(assertString(fs, "vespa", 
"vespa", Hits().add({0, 0}))); TEST_DO(assertString(fs, "vespar", "vespa", Hits())); TEST_DO(assertString(fs, "vespa", "vespar", Hits())); TEST_DO(assertString(fs, "vespa", "vespa vespa", Hits())); TEST_DO(assertString(fs, "vesp", "vespa", Hits())); - TEST_DO(assertString(fs, "vesp*", "vespa", Hits().add(0))); - TEST_DO(assertString(fs, "hutte", "hutte", Hits().add(0))); - TEST_DO(assertString(fs, "hütte", "hütte", Hits().add(0))); + TEST_DO(assertString(fs, "vesp*", "vespa", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "hutte", "hutte", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "hütte", "hütte", Hits().add({0, 0}))); TEST_DO(assertString(fs, "hutte", "hütte", Hits())); TEST_DO(assertString(fs, "hütte", "hutte", Hits())); TEST_DO(assertString(fs, "hütter", "hütte", Hits())); @@ -596,27 +673,27 @@ TEST("utf8 exact match") { TEST("utf8 flexible searcher (except regex)"){ UTF8FlexibleStringFieldSearcher fs(0); // regular - assertString(fs, "vespa", "vespa", Hits().add(0)); + assertString(fs, "vespa", "vespa", Hits().add({0, 0})); assertString(fs, "vesp", "vespa", Hits()); assertString(fs, "esp", "vespa", Hits()); assertString(fs, "espa", "vespa", Hits()); // prefix - assertString(fs, "vesp*", "vespa", Hits().add(0)); + assertString(fs, "vesp*", "vespa", Hits().add({0, 0})); fs.match_type(FieldSearcher::PREFIX); - assertString(fs, "vesp", "vespa", Hits().add(0)); + assertString(fs, "vesp", "vespa", Hits().add({0, 0})); // substring fs.match_type(FieldSearcher::REGULAR); - assertString(fs, "*esp*", "vespa", Hits().add(0)); + assertString(fs, "*esp*", "vespa", Hits().add({0, 0})); fs.match_type(FieldSearcher::SUBSTRING); - assertString(fs, "esp", "vespa", Hits().add(0)); + assertString(fs, "esp", "vespa", Hits().add({0, 0})); // suffix fs.match_type(FieldSearcher::REGULAR); - assertString(fs, "*espa", "vespa", Hits().add(0)); + assertString(fs, "*espa", "vespa", Hits().add({0, 0})); fs.match_type(FieldSearcher::SUFFIX); - assertString(fs, "espa", "vespa", Hits().add(0)); + assertString(fs, "espa", "vespa", Hits().add({0, 0})); fs.match_type(FieldSearcher::REGULAR); EXPECT_TRUE(testStringFieldInfo(fs)); @@ -625,11 +702,11 @@ TEST("utf8 flexible searcher (except regex)"){ TEST("utf8 flexible searcher handles regex and by default has case-insensitive partial match semantics") { UTF8FlexibleStringFieldSearcher fs(0); // Note: the # term prefix is a magic term-as-regex symbol used only for tests in this file - TEST_DO(assertString(fs, "#abc", "ABC", Hits().add(0))); - TEST_DO(assertString(fs, "#bc", "ABC", Hits().add(0))); - TEST_DO(assertString(fs, "#ab", "ABC", Hits().add(0))); - TEST_DO(assertString(fs, "#[a-z]", "ABC", Hits().add(0))); - TEST_DO(assertString(fs, "#(zoid)(berg)", "why not zoidberg?", Hits().add(0))); + TEST_DO(assertString(fs, "#abc", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#bc", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#ab", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#[a-z]", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#(zoid)(berg)", "why not zoidberg?", Hits().add({0, 0}))); TEST_DO(assertString(fs, "#[a-z]", "123", Hits())); } @@ -637,23 +714,129 @@ TEST("utf8 flexible searcher handles case-sensitive regex matching") { UTF8FlexibleStringFieldSearcher fs(0); fs.normalize_mode(Normalizing::NONE); TEST_DO(assertString(fs, "#abc", "ABC", Hits())); - TEST_DO(assertString(fs, "#abc", "abc", Hits().add(0))); - TEST_DO(assertString(fs, "#[A-Z]", "A", Hits().add(0))); - TEST_DO(assertString(fs, "#[A-Z]", "ABC", 
Hits().add(0))); + TEST_DO(assertString(fs, "#abc", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#[A-Z]", "A", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "#[A-Z]", "ABC", Hits().add({0, 0}))); TEST_DO(assertString(fs, "#[A-Z]", "abc", Hits())); } TEST("utf8 flexible searcher handles regexes with explicit anchoring") { UTF8FlexibleStringFieldSearcher fs(0); - TEST_DO(assertString(fs, "#^foo", "food", Hits().add(0))); + TEST_DO(assertString(fs, "#^foo", "food", Hits().add({0, 0}))); TEST_DO(assertString(fs, "#^foo", "afoo", Hits())); - TEST_DO(assertString(fs, "#foo$", "afoo", Hits().add(0))); + TEST_DO(assertString(fs, "#foo$", "afoo", Hits().add({0, 0}))); TEST_DO(assertString(fs, "#foo$", "food", Hits())); - TEST_DO(assertString(fs, "#^foo$", "foo", Hits().add(0))); + TEST_DO(assertString(fs, "#^foo$", "foo", Hits().add({0, 0}))); TEST_DO(assertString(fs, "#^foo$", "food", Hits())); TEST_DO(assertString(fs, "#^foo$", "oo", Hits())); } +TEST("utf8 flexible searcher regex matching treats field as 1 word") { + UTF8FlexibleStringFieldSearcher fs(0); + // Match case + TEST_DO(assertFieldInfo(fs, "#.*", "foo bar baz", QTFieldInfo(0, 1, 1))); + // Mismatch case + TEST_DO(assertFieldInfo(fs, "#^zoid$", "foo bar baz", QTFieldInfo(0, 0, 1))); +} + +TEST("utf8 flexible searcher handles fuzzy search in uncased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + // Term syntax (only applies to these tests): + // %{k}term => fuzzy match "term" with max edits k + // %{k,p}term => fuzzy match "term" with max edits k, prefix lock length p + + // DFA is used for k in {1, 2} + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}abc", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}Abc", "abd", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}abc", "ABCD", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}abc", "abcde", Hits())); + TEST_DO(assertString(fs, "%{2}abc", "abcde", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{2}abc", "xabcde", Hits())); + // Fallback to non-DFA matcher when k not in {1, 2} + TEST_DO(assertString(fs, "%{3}abc", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ!", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search in cased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}abc", "Abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits())); + TEST_DO(assertString(fs, "%{2}Abc", "abc", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{2}abc", "AbC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3}abc", "ABC", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3}abc", "ABCD", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search with prefix locking") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBerg", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBergg", Hits().add({0, 0}))); + TEST_DO(assertString(fs, 
"%{1,4}zoidberg", "zoidborg", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidblergh", Hits())); + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidblergh", Hits().add({0, 0}))); + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidblergh", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidbooorg", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidzooorg", Hits())); + + fs.normalize_mode(Normalizing::NONE); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{1,4}ZOID", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBerg", Hits().add({0, 0}))); // 1 edit + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBblerg", Hits())); // 2 edits, 1 max + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidBblerg", Hits().add({0, 0}))); // 2 edits, 2 max + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidBERG", Hits())); // 4 edits, 3 max + TEST_DO(assertString(fs, "%{4,4}zoidberg", "zoidBERG", Hits().add({0, 0}))); // 4 edits, 4 max +} + +TEST("utf8 flexible searcher fuzzy match with max_edits=0 implies exact match") { + UTF8FlexibleStringFieldSearcher fs(0); + TEST_DO(assertString(fs, "%{0}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits().add({0, 0}))); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoid", Hits().add({0, 0}))); +} + +TEST("utf8 flexible searcher caps oversized fuzzy prefix length to term length") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,5}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "boid", Hits())); + // Fallback + TEST_DO(assertString(fs, "%{0,5}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{5,5}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{0,9001}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "zoid", Hits().add({0, 0}))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "boid", Hits())); +} + +TEST("utf8 flexible searcher fuzzy matching treats field as 1 word") { + UTF8FlexibleStringFieldSearcher fs(0); + // Match case + TEST_DO(assertFieldInfo(fs, "%{1}foo bar baz", "foo jar baz", QTFieldInfo(0, 1, 1))); + // Mismatch case + TEST_DO(assertFieldInfo(fs, "%{1}foo", "foo bar baz", QTFieldInfo(0, 0, 1))); +} + TEST("bool search") { BoolFieldSearcher fs(0); TEST_DO(assertBool(fs, "true", true, true)); @@ -692,9 +875,9 @@ TEST("integer search") TEST_DO(assertInt(fs, StringList().add("9").add("10"), 10, BoolList().add(false).add(true))); TEST_DO(assertInt(fs, StringList().add("10").add(">9"), 10, BoolList().add(true).add(true))); - TEST_DO(assertInt(fs, "10", LongList().add(10).add(20).add(10).add(30), Hits().add(0).add(2))); + TEST_DO(assertInt(fs, "10", LongList().add(10).add(20).add(10).add(30), Hits().add({0, 0}).add({2, 0}))); TEST_DO(assertInt(fs, StringList().add("10").add("20"), LongList().add(10).add(20).add(10).add(30), - 
HitsList().add(Hits().add(0).add(2)).add(Hits().add(1)))); + HitsList().add(Hits().add({0, 0}).add({2, 0})).add(Hits().add({1, 0})))); TEST_DO(assertFieldInfo(fs, "10", 10, QTFieldInfo(0, 1, 1))); TEST_DO(assertFieldInfo(fs, "10", LongList().add(10).add(20).add(10).add(30), QTFieldInfo(0, 2, 4))); @@ -727,9 +910,9 @@ TEST("floating point search") TEST_DO(assertFloat(fs, StringList().add("10").add("10.5"), 10.5, BoolList().add(false).add(true))); TEST_DO(assertFloat(fs, StringList().add(">10.4").add("10.5"), 10.5, BoolList().add(true).add(true))); - TEST_DO(assertFloat(fs, "10.5", FloatList().add(10.5).add(20.5).add(10.5).add(30.5), Hits().add(0).add(2))); + TEST_DO(assertFloat(fs, "10.5", FloatList().add(10.5).add(20.5).add(10.5).add(30.5), Hits().add({0, 0}).add({2, 0}))); TEST_DO(assertFloat(fs, StringList().add("10.5").add("20.5"), FloatList().add(10.5).add(20.5).add(10.5).add(30.5), - HitsList().add(Hits().add(0).add(2)).add(Hits().add(1)))); + HitsList().add(Hits().add({0, 0}).add({2, 0})).add(Hits().add({1, 0})))); TEST_DO(assertFieldInfo(fs, "10.5", 10.5, QTFieldInfo(0, 1, 1))); TEST_DO(assertFieldInfo(fs, "10.5", FloatList().add(10.5).add(20.5).add(10.5).add(30.5), QTFieldInfo(0, 2, 4))); @@ -925,8 +1108,23 @@ TEST("counting of words") { // check that 'a' is counted as 1 word UTF8StrChrFieldSearcher fs(0); StringList field = StringList().add("a").add("aa bb cc"); - assertString(fs, "bb", field, Hits().add(2)); - assertString(fs, StringList().add("bb").add("not"), field, HitsList().add(Hits().add(2)).add(Hits())); + assertString(fs, "bb", field, Hits().add({1, 1})); + assertString(fs, StringList().add("bb").add("not"), field, HitsList().add(Hits().add({1, 1})).add(Hits())); +} + +TEST("element lengths") +{ + UTF8StrChrFieldSearcher fs(0); + auto field = StringList().add("a").add("b a c").add("d a"); + auto query = StringList().add("a"); + auto qtv = performSearch(fs, query, getFieldValue(field)); + EXPECT_EQUAL(1u, qtv.size()); + auto& qt = *qtv[0]; + auto& hl = qt.getHitList(); + EXPECT_EQUAL(3u, hl.size()); + EXPECT_EQUAL(1u, hl[0].element_length()); + EXPECT_EQUAL(3u, hl[1].element_length()); + EXPECT_EQUAL(2u, hl[2].element_length()); } vespalib::string NormalizationInput = "test That Somehing happens with during NÃ¥rmØlization"; diff --git a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp index 095141c0359..79bacda3f3b 100644 --- a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp @@ -3,6 +3,7 @@ #include "matching_elements_filler.h" #include <vespa/searchlib/common/matching_elements.h> #include <vespa/searchlib/common/matching_elements_fields.h> +#include <vespa/searchlib/query/streaming/same_element_query_node.h> #include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/vsm/searcher/fieldsearcher.h> #include <vespa/vdslib/container/searchresult.h> @@ -109,7 +110,7 @@ Matcher::add_matching_elements(const vespalib::string& field_name, uint32_t doc_ { _elements.clear(); for (auto& hit : hit_list) { - _elements.emplace_back(hit.elemId()); + _elements.emplace_back(hit.element_id()); } if (_elements.size() > 1) { std::sort(_elements.begin(), _elements.end()); diff --git a/streamingvisitors/src/vespa/searchvisitor/querywrapper.h b/streamingvisitors/src/vespa/searchvisitor/querywrapper.h index 27b5a2e12d4..b24f695196e 100644 --- 
a/streamingvisitors/src/vespa/searchvisitor/querywrapper.h +++ b/streamingvisitors/src/vespa/searchvisitor/querywrapper.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/searchlib/query/streaming/phrase_query_node.h> #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/querynode.h> diff --git a/streamingvisitors/src/vespa/searchvisitor/rankmanager.cpp b/streamingvisitors/src/vespa/searchvisitor/rankmanager.cpp index cdaf14eef9b..eebd9a79c07 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankmanager.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankmanager.cpp @@ -45,10 +45,15 @@ RankManager::Snapshot::addProperties(const vespa::config::search::RankProfilesCo FieldInfo::DataType to_data_type(VsmfieldsConfig::Fieldspec::Searchmethod search_method) { - if (search_method == VsmfieldsConfig::Fieldspec::Searchmethod::NEAREST_NEIGHBOR) { + // detecting DataType from Searchmethod will not give correct results, + // we should probably use the document type + if (search_method == VsmfieldsConfig::Fieldspec::Searchmethod::NEAREST_NEIGHBOR || + search_method == VsmfieldsConfig::Fieldspec::Searchmethod::NONE) + { return FieldInfo::DataType::TENSOR; } // This is the default FieldInfo data type if not specified. + // Wrong in most cases. return FieldInfo::DataType::DOUBLE; } diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 6b15b7cb88e..070563859a5 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -114,7 +114,7 @@ RankProcessor::initQueryEnvironment() void RankProcessor::initHitCollector(size_t wantedHitCount) { - _hitCollector.reset(new HitCollector(wantedHitCount)); + _hitCollector = std::make_unique<HitCollector>(wantedHitCount); } void @@ -209,9 +209,8 @@ class RankProgramWrapper : public HitCollector::IRankProgram { private: MatchData &_match_data; - public: - RankProgramWrapper(MatchData &match_data) : _match_data(match_data) {} + explicit RankProgramWrapper(MatchData &match_data) : _match_data(match_data) {} void run(uint32_t docid, const std::vector<search::fef::TermFieldMatchData> &matchData) override { // Prepare the match data object used by the rank program with earlier unpacked match data. 
copyTermFieldMatchData(matchData, _match_data); @@ -300,7 +299,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { - uint32_t fieldId = hit.context(); + uint32_t fieldId = hit.field_id(); if (fieldId != lastFieldId) { // reset to notfound/unknown values tmd = nullptr; @@ -335,8 +334,8 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap } if (tmd != nullptr) { // adjust so that the position for phrase terms equals the match for the first term - TermFieldMatchDataPosition pos(hit.elemId(), hit.wordpos() - term.getPosAdjust(), - hit.weight(), fieldLen); + TermFieldMatchDataPosition pos(hit.element_id(), hit.position() - term.getPosAdjust(), + hit.element_weight(), hit.element_length()); tmd->appendPosition(pos); LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)", pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight()); diff --git a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp index cdd1a018d84..979e5f25b6a 100644 --- a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp @@ -391,7 +391,17 @@ SearchVisitor::init(const Parameters & params) } _queryResult->getSearchResult().setWantedHitCount(wantedSummaryCount); + vespalib::stringref sortRef; + bool hasSortSpec = params.lookup("sort", sortRef); + vespalib::stringref groupingRef; + bool hasGrouping = params.lookup("aggregation", groupingRef); + if (params.lookup("rankprofile", valueRef) ) { + if ( ! hasGrouping && (wantedSummaryCount == 0)) { + // If no hits and no grouping, just use unranked profile + // TODO, optional could also include check for if grouping needs rank + valueRef = "unranked"; + } vespalib::string tmp(valueRef.data(), valueRef.size()); _rankController.setRankProfile(tmp); LOG(debug, "Received rank profile: %s", _rankController.getRankProfile().c_str()); @@ -442,9 +452,9 @@ SearchVisitor::init(const Parameters & params) if (_env) { _init_called = true; - if ( params.lookup("sort", valueRef) ) { + if ( hasSortSpec ) { search::uca::UcaConverterFactory ucaFactory; - _sortSpec = search::common::SortSpec(vespalib::string(valueRef.data(), valueRef.size()), ucaFactory); + _sortSpec = search::common::SortSpec(vespalib::string(sortRef.data(), sortRef.size()), ucaFactory); LOG(debug, "Received sort specification: '%s'", _sortSpec.getSpec().c_str()); } @@ -494,10 +504,10 @@ SearchVisitor::init(const Parameters & params) LOG(warning, "No query received"); } - if (params.lookup("aggregation", valueRef) ) { + if (hasGrouping) { std::vector<char> newAggrBlob; - newAggrBlob.resize(valueRef.size()); - memcpy(&newAggrBlob[0], valueRef.data(), newAggrBlob.size()); + newAggrBlob.resize(groupingRef.size()); + memcpy(&newAggrBlob[0], groupingRef.data(), newAggrBlob.size()); LOG(debug, "Received new aggregation blob of %zd bytes", newAggrBlob.size()); setupGrouping(newAggrBlob); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/boolfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/boolfieldsearcher.cpp index d0cfa4d9956..aa25b0e75d3 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/boolfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/boolfieldsearcher.cpp @@ -53,7 +53,7 @@ void BoolFieldSearcher::onValue(const document::FieldValue & fv) addHit(*_qtl[j], 
0); } } - ++_words; + set_element_length(1); } } diff --git a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.cpp index 5e06ae41a03..c75ab7fccd3 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.cpp @@ -5,6 +5,7 @@ #include <vespa/document/fieldvalue/weightedsetfieldvalue.h> #include <vespa/searchlib/query/streaming/multi_term.h> #include <vespa/vespalib/stllike/hash_set.h> +#include <cassert> #include <vespa/log/log.h> LOG_SETUP(".vsm.searcher.fieldsearcher"); @@ -55,6 +56,7 @@ FieldSearcher::FieldSearcher(FieldIdT fId, bool defaultPrefix) noexcept _maxFieldLength(0x100000), _currentElementId(0), _currentElementWeight(1), + _element_length_fixups(), _words(0), _badUtf8Count(0) { @@ -70,6 +72,7 @@ FieldSearcher::search(const StorageDocument & doc) fInfo.setHitOffset(qt->getHitList().size()); } onSearch(doc); + assert(_element_length_fixups.empty()); for (auto qt : _qtl) { QueryTerm::FieldInfo & fInfo = qt->getFieldInfo(field()); fInfo.setHitCount(qt->getHitList().size() - fInfo.getHitOffset()); @@ -276,4 +279,16 @@ FieldSearcher::IteratorHandler::onStructStart(const Content & c) _searcher.onStructValue(static_cast<const document::StructFieldValue &>(c.getValue())); } +void +FieldSearcher::set_element_length(uint32_t element_length) +{ + _words += element_length; + if (!_element_length_fixups.empty()) { + for (auto& fixup : _element_length_fixups) { + fixup.first->set_element_length(fixup.second, element_length); + } + _element_length_fixups.clear(); + } +} + } diff --git a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h index c5bca6f3899..2af68c553b8 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h @@ -6,6 +6,7 @@ #include <vespa/vsm/common/document.h> #include <vespa/vsm/common/storagedocument.h> #include <vespa/vespalib/util/array.h> +#include <utility> namespace search::fef { class IQueryEnvironment; } @@ -46,7 +47,7 @@ public: explicit FieldSearcher(FieldIdT fId) noexcept : FieldSearcher(fId, false) {} FieldSearcher(FieldIdT fId, bool defaultPrefix) noexcept; ~FieldSearcher() override; - virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; + [[nodiscard]] virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; bool search(const StorageDocument & doc); virtual void prepare(search::streaming::QueryTermList& qtl, const SharedSearcherBuf& buf, const vsm::FieldPathMapT& field_paths, search::fef::IQueryEnvironment& query_env); @@ -96,6 +97,7 @@ private: unsigned _maxFieldLength; uint32_t _currentElementId; int32_t _currentElementWeight; // Contains the weight of the current item being evaluated. + std::vector<std::pair<search::streaming::QueryTerm*, uint32_t>> _element_length_fixups; protected: /// Number of terms searched. unsigned _words; @@ -105,9 +107,10 @@ protected: * Adds a hit to the given query term. * For each call to onValue() a batch of words are processed, and the position is local to this batch. 
**/ - void addHit(search::streaming::QueryTerm & qt, uint32_t pos) const { - qt.add(_words + pos, field(), _currentElementId, _currentElementWeight); + void addHit(search::streaming::QueryTerm & qt, uint32_t pos) { + _element_length_fixups.emplace_back(&qt, qt.add(field(), _currentElementId, _currentElementWeight, pos)); } + void set_element_length(uint32_t element_length); public: static search::byte _foldLowCase[256]; static search::byte _wordChar[256]; diff --git a/streamingvisitors/src/vespa/vsm/searcher/floatfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/floatfieldsearcher.cpp index 8558522003f..70e5bb4b82c 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/floatfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/floatfieldsearcher.cpp @@ -55,7 +55,7 @@ void FloatFieldSearcherT<T>::onValue(const document::FieldValue & fv) addHit(*_qtl[j], 0); } } - ++_words; + set_element_length(1); } template<typename T> diff --git a/streamingvisitors/src/vespa/vsm/searcher/geo_pos_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/geo_pos_field_searcher.cpp index 5ecc9a5a06e..bbeb3be986f 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/geo_pos_field_searcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/geo_pos_field_searcher.cpp @@ -58,7 +58,7 @@ void GeoPosFieldSearcher::onStructValue(const document::StructFieldValue & fv) { addHit(*_qtl[j], 0); } } - ++_words; + set_element_length(1); } bool GeoPosFieldSearcher::GeoPosInfo::cmp(const document::StructFieldValue & sfv) const { diff --git a/streamingvisitors/src/vespa/vsm/searcher/intfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/intfieldsearcher.cpp index e73c7f5c1a7..3984254274f 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/intfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/intfieldsearcher.cpp @@ -43,7 +43,7 @@ void IntFieldSearcher::onValue(const document::FieldValue & fv) addHit(*_qtl[j], 0); } } - ++_words; + set_element_length(1); } } diff --git a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp index c0a0249125f..673cf11b2cf 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp @@ -25,36 +25,42 @@ void StrChrFieldSearcher::onValue(const document::FieldValue & fv) bool StrChrFieldSearcher::matchDoc(const FieldRef & fieldRef) { - if (_qtl.size() > 1) { - size_t mintsz = shortestTerm(); - if (fieldRef.size() >= mintsz) { - _words += matchTerms(fieldRef, mintsz); + size_t element_length = 0; + bool need_count_words = false; + if (_qtl.size() > 1) { + size_t mintsz = shortestTerm(); + if (fieldRef.size() >= mintsz) { + element_length = matchTerms(fieldRef, mintsz); + } else { + need_count_words = true; + } } else { - _words += countWords(fieldRef); + for (auto qt : _qtl) { + if (fieldRef.size() >= qt->termLen() || qt->isRegex() || qt->isFuzzy()) { + element_length = std::max(element_length, matchTerm(fieldRef, *qt)); + } else { + need_count_words = true; + } + } } - } else { - for (auto qt : _qtl) { - if (fieldRef.size() >= qt->termLen() || qt->isRegex()) { - _words += matchTerm(fieldRef, *qt); - } else { - _words += countWords(fieldRef); - } + if (need_count_words) { + element_length = std::max(element_length, countWords(fieldRef)); } - } - return true; + set_element_length(element_length); + return true; } size_t StrChrFieldSearcher::shortestTerm() const { - size_t 
mintsz(_qtl.front()->termLen()); - for (auto it=_qtl.begin()+1, mt=_qtl.end(); it != mt; it++) { - const QueryTerm & qt = **it; - if (qt.isRegex()) { - return 0; // Must avoid "too short query term" optimization when using regex + size_t mintsz(_qtl.front()->termLen()); + for (auto it=_qtl.begin()+1, mt=_qtl.end(); it != mt; it++) { + const QueryTerm & qt = **it; + if (qt.isRegex() || qt.isFuzzy()) { + return 0; // Must avoid "too short query term" optimization when using regex or fuzzy + } + mintsz = std::min(mintsz, qt.termLen()); } - mintsz = std::min(mintsz, qt.termLen()); - } - return mintsz; + return mintsz; } } diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp index c6deb6eacd1..4a8e7a43475 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "utf8flexiblestringfieldsearcher.h" +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> +#include <algorithm> #include <cassert> #include <vespa/log/log.h> @@ -23,7 +25,7 @@ UTF8FlexibleStringFieldSearcher::matchTerms(const FieldRef & f, const size_t min (void) mintsz; size_t words = 0; for (auto qt : _qtl) { - words = matchTerm(f, *qt); + words = std::max(words, matchTerm(f, *qt)); } return words; } @@ -36,7 +38,20 @@ UTF8FlexibleStringFieldSearcher::match_regexp(const FieldRef & f, search::stream if (regexp_term->regexp().partial_match({f.data(), f.size()})) { addHit(qt, 0); } - return countWords(f); + return 1; +} + +size_t +UTF8FlexibleStringFieldSearcher::match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt) +{ + auto* fuzzy_term = qt.as_fuzzy_term(); + assert(fuzzy_term != nullptr); + // TODO delegate to matchTermExact if max edits == 0? 
+ // - needs to avoid folding to have consistent normalization semantics + if (fuzzy_term->is_match({f.data(), f.size()})) { + addHit(qt, 0); + } + return 1; } size_t @@ -57,6 +72,9 @@ UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) } else if (qt.isRegex()) { LOG(debug, "Use regexp match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); return match_regexp(f, qt); + } else if (qt.isFuzzy()) { + LOG(debug, "Use fuzzy match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); + return match_fuzzy(f, qt); } else { if (substring()) { LOG(debug, "Use substring match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h index cd1715ad158..a5f6ad46246 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h @@ -25,6 +25,7 @@ private: size_t matchTerms(const FieldRef & f, size_t shortestTerm) override; size_t match_regexp(const FieldRef & f, search::streaming::QueryTerm & qt); + size_t match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt); public: std::unique_ptr<FieldSearcher> duplicate() const override; diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 9c8bb2f185a..63d2007cecf 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -133,7 +133,7 @@ FieldSearchSpec::reconfig(const QueryTerm & term) (term.isSuffix() && _arg1 != "suffix") || (term.isExactstring() && _arg1 != "exact") || (term.isPrefix() && _arg1 == "suffix") || - term.isRegex()) + (term.isRegex() || term.isFuzzy())) { _searcher = std::make_unique<UTF8FlexibleStringFieldSearcher>(id()); propagate_settings_to_searcher(); diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml index f807f7c28be..cac79c3850e 100644 --- a/vespa-athenz/pom.xml +++ b/vespa-athenz/pom.xml @@ -243,32 +243,6 @@ </exclusions> </dependency> <dependency> - <groupId>com.amazonaws</groupId> - <artifactId>aws-java-sdk-core</artifactId> - <exclusions> - <exclusion> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-core</artifactId> - </exclusion> - <exclusion> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-databind</artifactId> - </exclusion> - <exclusion> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-annotations</artifactId> - </exclusion> - <exclusion> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpclient</artifactId> - </exclusion> - <exclusion> - <groupId>commons-logging</groupId> - <artifactId>commons-logging</artifactId> - </exclusion> - </exclusions> - </dependency> - <dependency> <groupId>com.auth0</groupId> <artifactId>java-jwt</artifactId> <exclusions> diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java index 564c1144cc0..c4c8fac87b4 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java @@ -204,6 +204,7 @@ public class DefaultZmsClient extends ClientBase implements ZmsClient { .build(); return execute(request, response -> { 
DomainListResponseEntity result = readEntity(response, DomainListResponseEntity.class); + if (result.domains == null) return List.of(); return result.domains.stream().map(AthenzDomain::new).toList(); }); } @@ -216,6 +217,7 @@ public class DefaultZmsClient extends ClientBase implements ZmsClient { .build(); return execute(request, response -> { DomainListResponseEntity result = readEntity(response, DomainListResponseEntity.class); + if (result.domains == null) return List.of(); return result.domains.stream().map(AthenzDomain::new).toList(); }); } diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java index 19e0c0dc77d..5ff7b4592a1 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java @@ -435,7 +435,8 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { return ignoredContent; } - private ContentChannel getDocument(HttpRequest request, DocumentPath path, ResponseHandler handler) { + private ContentChannel getDocument(HttpRequest request, DocumentPath path, ResponseHandler rawHandler) { + ResponseHandler handler = new MeasuringResponseHandler(request, rawHandler, com.yahoo.documentapi.metrics.DocumentOperationType.GET, clock.instant()); disallow(request, DRY_RUN); enqueueAndDispatch(request, handler, () -> { DocumentOperationParameters rawParameters = parametersFromRequest(request, CLUSTER, FIELD_SET); @@ -1057,7 +1058,7 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { private ParsedDocumentOperation parse(InputStream inputStream, String docId, DocumentOperationType operation) { try { - return new JsonReader(manager, inputStream, jsonFactory).readSingleDocument(operation, docId); + return new JsonReader(manager, inputStream, jsonFactory).readSingleDocumentStreaming(operation, docId); } catch (IllegalArgumentException e) { incrementMetricParseError(); throw e; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/documentapi/metrics/DocumentOperationType.java b/vespaclient-container-plugin/src/main/java/com/yahoo/documentapi/metrics/DocumentOperationType.java index 1c0b8c560ac..63bf520f4d3 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/documentapi/metrics/DocumentOperationType.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/documentapi/metrics/DocumentOperationType.java @@ -11,7 +11,7 @@ import com.yahoo.messagebus.Message; */ public enum DocumentOperationType { - PUT, REMOVE, UPDATE, ERROR; + GET, PUT, REMOVE, UPDATE, ERROR; public static DocumentOperationType fromMessage(Message msg) { if (msg instanceof PutDocumentMessage) { diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/MetricNames.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/MetricNames.java index bf740014edd..efcffb16a2b 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/MetricNames.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/MetricNames.java @@ -7,9 +7,10 @@ import ai.vespa.metrics.ContainerMetrics; * Place to store the metric names so where the metrics are logged can be found * more easily in an IDE. 
* - * @author steinar + * @author Steinar Knutsen */ public final class MetricNames { + public static final String NUM_OPERATIONS = ContainerMetrics.HTTPAPI_NUM_OPERATIONS.baseName(); public static final String NUM_PUTS = ContainerMetrics.HTTPAPI_NUM_PUTS.baseName(); public static final String NUM_REMOVES = ContainerMetrics.HTTPAPI_NUM_REMOVES.baseName(); @@ -26,7 +27,6 @@ public final class MetricNames { public static final String FAILED_TIMEOUT = ContainerMetrics.HTTPAPI_FAILED_TIMEOUT.baseName(); public static final String FAILED_INSUFFICIENT_STORAGE = ContainerMetrics.HTTPAPI_FAILED_INSUFFICIENT_STORAGE.baseName(); - private MetricNames() { - } + private MetricNames() { } } diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java index c8fcb4c4635..04639db4dac 100644 --- a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java @@ -411,6 +411,7 @@ public class DocumentV1ApiTest { DocumentUpdate expectedUpdate = new DocumentUpdate(doc3.getDataType(), doc3.getId()); expectedUpdate.addFieldUpdate(FieldUpdate.createAssign(doc3.getField("artist"), new StringFieldValue("Lisa Ekdahl"))); expectedUpdate.setCondition(new TestAndSetCondition("true")); + expectedUpdate.setCreateIfNonExistent(true); assertEquals(expectedUpdate, update); parameters.responseHandler().get().handleResponse(new UpdateResponse(0, false)); assertEquals(parameters().withRoute("content"), parameters); @@ -419,10 +420,16 @@ public class DocumentV1ApiTest { response = driver.sendRequest("http://localhost/document/v1/space/music/docid?selection=true&cluster=content&timeChunk=10", PUT, """ { + "extra-ignored-field": { "foo": [{ }], "bar": null }, + "another-ignored-field": [{ "foo": [{ }] }], + "remove": "id:ns:type::ignored", + "put": "id:ns:type::ignored", "fields": { "artist": { "assign": "Lisa Ekdahl" }, "nonexisting": { "assign": "Ignored" } - } + }, + "post": "id:ns:type::ignored", + "create": true }"""); assertSameJson(""" { @@ -778,7 +785,7 @@ public class DocumentV1ApiTest { response = driver.sendRequest("http://localhost/document/v1/space/music/number/1/two?condition=test%20it", POST, ""); assertSameJson("{" + " \"pathId\": \"/document/v1/space/music/number/1/two\"," + - " \"message\": \"Could not read document, no document?\"" + + " \"message\": \"expected start of root object, got null\"" + "}", response.readAll()); assertEquals(400, response.getStatus()); @@ -791,7 +798,8 @@ "}"); Inspector responseRoot = SlimeUtils.jsonToSlime(response.readAll()).get(); assertEquals("/document/v1/space/music/number/1/two", responseRoot.field("pathId").asString()); - assertTrue(responseRoot.field("message").asString().startsWith("Unexpected character ('┻' (code 9531 / 0x253b)): was expecting double-quote to start field name")); + assertTrue(responseRoot.field("message").asString(), + responseRoot.field("message").asString().startsWith("failed parsing document: Unexpected character ('┻' (code 9531 / 0x253b)): was expecting double-quote to start field name")); assertEquals(400, response.getStatus()); // PUT on a unknown document type is a 400 diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 10b0478b5b0..7f77582ea81 100644 --- a/vespajlib/abi-spec.json +++ 
b/vespajlib/abi-spec.json @@ -720,6 +720,22 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static com.yahoo.tensor.DirectIndexedAddress of(com.yahoo.tensor.DimensionSizes)", + "public void setIndex(int, int)", + "public long getDirectIndex()", + "public long[] getIndexes()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,8 +910,11 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public abstract double get(long)", "public abstract float getFloat(long)", @@ -949,8 +968,10 @@ "methods" : [ "public com.yahoo.tensor.TensorType type()", "public long size()", + "public int sizeAsInt()", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", "public java.util.Map cells()", @@ -1031,6 +1052,7 @@ "public com.yahoo.tensor.TensorType type()", "public long size()", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", @@ -1153,9 +1175,11 @@ "methods" : [ "public abstract com.yahoo.tensor.TensorType type()", "public boolean isEmpty()", - "public abstract long size()", + "public long size()", + "public int sizeAsInt()", "public abstract double get(com.yahoo.tensor.TensorAddress)", "public abstract boolean has(com.yahoo.tensor.TensorAddress)", + "public abstract java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public abstract java.util.Iterator cellIterator()", "public abstract java.util.Iterator valueIterator()", "public abstract java.util.Map cells()", @@ -1243,7 +1267,9 @@ "public static com.yahoo.tensor.Tensor from(java.lang.String)", "public static com.yahoo.tensor.Tensor from(double)" ], - "fields" : [ ] + "fields" : [ + "public static final int invalidIndex" + ] }, "com.yahoo.tensor.TensorAddress$Builder" : { "superClass" : "java.lang.Object", @@ -1255,6 +1281,8 @@ "public void <init>(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String)", "public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, java.lang.String)", + "public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, long)", + "public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, int)", "public com.yahoo.tensor.TensorAddress$Builder copy()", "public com.yahoo.tensor.TensorType type()", "public com.yahoo.tensor.TensorAddress build()" @@ -1287,16 +1315,19 @@ "public static com.yahoo.tensor.TensorAddress of(java.lang.String[])", 
"public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])", "public static varargs com.yahoo.tensor.TensorAddress of(long[])", + "public static varargs com.yahoo.tensor.TensorAddress of(int[])", "public abstract int size()", "public abstract java.lang.String label(int)", "public abstract long numericLabel(int)", "public abstract com.yahoo.tensor.TensorAddress withLabel(int, long)", "public final boolean isEmpty()", "public int compareTo(com.yahoo.tensor.TensorAddress)", - "public int hashCode()", - "public boolean equals(java.lang.Object)", + "public java.lang.String toString()", "public final java.lang.String toString(com.yahoo.tensor.TensorType)", "public static java.lang.String labelToString(java.lang.String)", + "public com.yahoo.tensor.TensorAddress partialCopy(int[])", + "public com.yahoo.tensor.TensorAddress fullAddressOf(java.util.List, int[])", + "public com.yahoo.tensor.TensorAddress mappedPartialAddress(com.yahoo.tensor.TensorType, java.util.List)", "public bridge synthetic int compareTo(java.lang.Object)" ], "fields" : [ ] @@ -1452,6 +1483,9 @@ ], "methods" : [ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", + "public boolean hasIndexedDimensions()", + "public boolean hasMappedDimensions()", + "public boolean hasOnlyIndexedBoundDimensions()", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", @@ -1462,6 +1496,7 @@ "public java.util.Set dimensionNames()", "public java.util.Optional dimension(java.lang.String)", "public java.util.Optional indexOfDimension(java.lang.String)", + "public int indexOfDimensionAsInt(java.lang.String)", "public java.util.Optional sizeOfDimension(java.lang.String)", "public boolean isAssignableTo(com.yahoo.tensor.TensorType)", "public boolean isConvertibleTo(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java index 92a9ed26085..7a3d34eca7b 100644 --- a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java +++ b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java @@ -8,8 +8,25 @@ import net.openhft.hashing.LongHashFunction; * @author baldersheim */ public class Hasher { + private final LongHashFunction hasher; /** Uses net.openhft.hashing.LongHashFunction.xx3() */ public static long xxh3(byte [] data) { return LongHashFunction.xx3().hashBytes(data); } + public static long xxh3(byte [] data, long seed) { + return LongHashFunction.xx3(seed).hashBytes(data); + } + + private Hasher(LongHashFunction hasher) { + this.hasher = hasher; + } + public static Hasher withSeed(long seed) { + return new Hasher(LongHashFunction.xx3(seed)); + } + public long hash(long v) { + return hasher.hashLong(v); + } + public long hash(String s) { + return hasher.hashChars(s); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 83a625f72ac..640fa609432 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -11,10 +11,19 @@ import java.util.Arrays; public final class DimensionSizes { private final long[] sizes; + private final long[] productOfSizesFromHereOn; + private final long totalSize; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; 
builder.sizes = null; // invalidate builder to avoid copying the array + this.productOfSizesFromHereOn = new long[sizes.length]; + long product = 1; + for (int i = sizes.length; i-- > 0; ) { + productOfSizesFromHereOn[i] = product; + product *= sizes[i]; + } + this.totalSize = product; } /** @@ -49,10 +58,11 @@ public final class DimensionSizes { /** Returns the product of the sizes of this */ public long totalSize() { - long productSize = 1; - for (long dimensionSize : sizes ) - productSize *= dimensionSize; - return productSize; + return totalSize; + } + + long productOfDimensionsAfter(int afterIndex) { + return productOfSizesFromHereOn[afterIndex]; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..cda3be47ddb --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,55 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + * + * @author baldersheim + */ +public final class DirectIndexedAddress { + + private final DimensionSizes sizes; + private final int[] indexes; + private long directIndex; + + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + + public static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + + /** Retrieve the index that can be used for direct lookup in an indexed tensor. */ + public long getDirectIndex() { return directIndex; } + + public long [] getIndexes() { + long[] asLong = new long[indexes.length]; + for (int i=0; i < indexes.length; i++) { + asLong[i] = indexes[i]; + } + return asLong; + } + + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 548d39dd767..53f50fc4d02 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -22,6 +22,10 @@ class IndexedDoubleTensor extends IndexedTensor { return values.length; } + /** Once we can store more cells than an int we should drop this method. 
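The iteration pattern described in the DirectIndexedAddress javadoc above, written out as a small sketch; the tensor literal and sizes are illustrative, and dimension 0/1 map to x/y because dimensions are kept sorted:

    import com.yahoo.tensor.DirectIndexedAddress;
    import com.yahoo.tensor.IndexedTensor;
    import com.yahoo.tensor.Tensor;

    // tensor(x[2],y[3]): an indexed-only type parses to an IndexedTensor.
    IndexedTensor tensor = (IndexedTensor) Tensor.from("tensor(x[2],y[3]):[[1,2,3],[4,5,6]]");

    DirectIndexedAddress address = tensor.directAddress();
    address.setIndex(0, 1);                 // lock dimension x at index 1
    long base = address.getDirectIndex();   // direct index of cell {x:1, y:0}
    long stride = address.getStride(1);     // step between consecutive y values

    double sum = 0;
    for (int y = 0; y < 3; y++) {
        sum += tensor.get(base + y * stride);   // visits cells {x:1, y:0..2}
    }
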
*/ + @Override + public int sizeAsInt() { return values.length; } + @Override public double get(long valueIndex) { return values[(int)valueIndex]; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 26560a70ac4..3085ef1a843 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -18,9 +18,11 @@ class IndexedFloatTensor extends IndexedTensor { } @Override - public long size() { - return values.length; - } + public long size() { return values.length; } + + /** Once we can store more cells than an int we should drop this. */ + @Override + public int sizeAsInt() { return values.length; } @Override public double get(long valueIndex) { return getFloat(valueIndex); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6a879fa533b..fc0473c635a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -90,9 +90,13 @@ public abstract class IndexedTensor implements Tensor { * @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return get((int)toValueIndex(indexes, dimensionSizes)); + return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * @@ -108,7 +112,7 @@ public abstract class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return get((int)toValueIndex(address, dimensionSizes, type)); + return get(toValueIndex(address, dimensionSizes, type)); } catch (IllegalArgumentException e) { return 0.0; @@ -116,6 +120,17 @@ public abstract class IndexedTensor implements Tensor { } @Override + public Double getAsDouble(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0 || size() <= index) return null; + return get(index); + } catch (IllegalArgumentException e) { + return null; + } + } + + @Override public boolean has(TensorAddress address) { try { long index = toValueIndex(address, dimensionSizes, type); @@ -150,30 +165,22 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds"); - valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; + valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i]; } return valueIndex; } static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) { - if (address.isEmpty()) return 0; - long valueIndex = 0; - for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + for (int i = 0, size = address.size(); i < size; i++) { + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * 
label; } return valueIndex; } - private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - long product = 1; - for (int i = afterIndex + 1; i < sizes.dimensions(); i++) - product *= sizes.size(i); - return product; - } - void throwOnIncompatibleType(TensorType type) { if ( ! this.type().isRenamableTo(type)) throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + @@ -227,7 +234,7 @@ public abstract class IndexedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { @@ -250,8 +257,7 @@ public abstract class IndexedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (tensor.type().valueType()) { @@ -264,8 +270,7 @@ public abstract class IndexedTensor implements Tensor { } // end bracket and comma - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } if (index == maxCells && index < tensor.size()) b.append(", ...]"); @@ -286,7 +291,7 @@ public abstract class IndexedTensor implements Tensor { } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -300,7 +305,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -314,7 +319,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -327,14 +332,13 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case 
BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -348,14 +352,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -369,14 +372,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } private static void validateSizes(DimensionSizes sizes, int length) { @@ -518,7 +520,7 @@ public abstract class IndexedTensor implements Tensor { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, - offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, + offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i, (List<Object>) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values for (long i = 0; i < currentDimension.size(); i++) { @@ -623,11 +625,11 @@ public abstract class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator<Double> { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override @@ -889,8 +891,8 @@ public abstract class IndexedTensor implements Tensor { 
private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { long size = 1; - for (int iterateDimension : iterateDimensions) - size *= sizes.size(iterateDimension); + for (int i = 0; i < iterateDimensions.size(); i++) + size *= sizes.size(iterateDimensions.get(i)); return size; } @@ -1056,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor { /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { - private long lastComputedSourceValueIndex = -1; + private long lastComputedSourceValueIndex = Tensor.invalidIndex; private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); @@ -1091,8 +1093,8 @@ public abstract class IndexedTensor implements Tensor { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes); - this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes); + this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension); + this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; @@ -1156,7 +1158,7 @@ public abstract class IndexedTensor implements Tensor { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.step = productOfDimensionsAfter(iterateDimension, sizes); + this.step = sizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index e196569b18f..3e0df5f2261 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -31,6 +30,10 @@ public class MappedTensor implements Tensor { @Override public long size() { return cells.size(); } + /** Once we can store more cells than an int we should drop this. 
*/ + @Override + public int sizeAsInt() { return cells.size(); } + @Override public double get(TensorAddress address) { return cells.getOrDefault(address, 0.0); } @@ -38,6 +41,9 @@ public class MappedTensor implements Tensor { public boolean has(TensorAddress address) { return cells.containsKey(address); } @Override + public Double getAsDouble(TensorAddress address) { return cells.get(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override @@ -79,7 +85,7 @@ public class MappedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 5d5a5f74063..65c6677e7e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -2,12 +2,13 @@ package com.yahoo.tensor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -28,7 +29,6 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int denseSubspaceSize; // XXX consider using "record" instead /** only exposed for internal use; subject to change without notice */ @@ -50,45 +50,15 @@ public class MixedTensor implements Tensor { } } - /** The cells in the tensor */ - private final List<DenseSubspace> denseSubspaces; - /** only exposed for internal use; subject to change without notice */ - public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } + public List<DenseSubspace> getInternalDenseSubspaces() { return index.denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { + private MixedTensor(TensorType type, Index index) { this.type = type; - this.denseSubspaceSize = index.denseSubspaceSize(); - this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.denseSubspaceSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); - } - long count = 0; - for (var block : this.denseSubspaces) { - if (index.sparseMap.get(block.sparseAddress) != count) { - throw new IllegalStateException("map vs list mismatch: block #" - + count - + " address maps to #" - + index.sparseMap.get(block.sparseAddress)); - } - if (block.cells.length != denseSubspaceSize) { - throw new IllegalStateException("dense subspace size mismatch, expected " - + denseSubspaceSize - + " cells, but got: " - + block.cells.length); - } - ++count; - } - if (count != index.sparseMap.size()) { - throw new IllegalStateException("mismatch: list size is " - + count - + " but map size is " - + index.sparseMap.size()); - } } /** Returns the tensor type */ @@ -97,32 
+67,34 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return denseSubspaces.size() * denseSubspaceSize; } + public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { + var block = index.blockOf(address); + int denseOffset = index.denseOffsetOf(address); + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } + return block.cells[denseOffset]; + } + + @Override + public Double getAsDouble(TensorAddress address) { + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - if (denseOffset < 0 || denseOffset >= block.cells.length) { - return 0.0; + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { + return null; } return block.cells[denseOffset]; } @Override public boolean has(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return false; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - return (denseOffset >= 0 && denseOffset < block.cells.length); + return (block != null && denseOffset >= 0 && denseOffset < block.cells.length); } /** @@ -135,21 +107,30 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); - DenseSubspace currBlock = null; - int currOffset = denseSubspaceSize; + + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); + final int[] labels = new int[index.indexedDimensions.size()]; + DenseSubspace currentBlock = null; + int currOffset = index.denseSubspaceSize; + int prevOffset = -1; + @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } + @Override public Cell next() { - if (currOffset == denseSubspaceSize) { - currBlock = blockIterator.next(); + if (currOffset == index.denseSubspaceSize) { + currentBlock = blockIterator.next(); currOffset = 0; } - TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, currOffset); - double value = currBlock.cells[currOffset++]; + if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1 + index.denseOffsetToAddress(currOffset, labels); + } + TensorAddress fullAddr = currentBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels); + prevOffset = currOffset; + double value = currentBlock.cells[currOffset++]; return new Cell(fullAddr, value); } }; @@ -162,20 +143,23 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); - double[] currBlock = null; - int currOffset = denseSubspaceSize; + + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); + double[] currentBlock = null; + int currOffset = index.denseSubspaceSize; + @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || 
blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } + @Override public Double next() { - if (currOffset == denseSubspaceSize) { - currBlock = blockIterator.next().cells; + if (currOffset == index.denseSubspaceSize) { + currentBlock = blockIterator.next().cells; currOffset = 0; } - return currBlock[currOffset++]; + return currentBlock[currOffset++]; } }; } @@ -197,24 +181,22 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, denseSubspaces, index); + return new MixedTensor(other, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseSubspace> list = new ArrayList<>(); - for (var block : denseSubspaces) { + for (var block : index.denseSubspaces) { if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part - indexBuilder.addBlock(block.sparseAddress, list.size()); - list.add(block); + indexBuilder.addBlock(block); } } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } @Override - public int hashCode() { return Objects.hash(type, denseSubspaces); } + public int hashCode() { return Objects.hash(type, index.denseSubspaces); } @Override public String toString() { @@ -249,13 +231,14 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return denseSubspaceSize; + return index.denseSubspaceSize; } /** * Base class for building mixed tensors. */ public abstract static class Builder implements Tensor.Builder { + static final int INITIAL_HASH_CAPACITY = 1000; final TensorType type; @@ -265,10 +248,11 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. 
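For context on what these builders produce, a sketch of constructing a small mixed tensor through the public entry point; it assumes the existing cell()/label()/value() chain on Tensor.Builder (the CellBuilder that appears further down in the Tensor.java hunk):

    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.TensorType;

    // One mapped dimension (key) and one bound indexed dimension (x): Builder.of selects BoundBuilder here.
    TensorType type = TensorType.fromSpec("tensor(key{},x[2])");
    Tensor mixed = Tensor.Builder.of(type)
            .cell().label("key", "a").label("x", 0).value(1.0)
            .cell().label("key", "a").label("x", 1).value(2.0)
            .cell().label("key", "b").label("x", 0).value(3.0)
            .cell().label("key", "b").label("x", 1).value(4.0)
            .build();
    // Each distinct "key" label owns one dense subspace of size 2.
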
*/ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + if (type.hasIndexedUnboundDimensions()) { + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } else { - return new BoundBuilder(type); + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -306,13 +290,14 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Map<TensorAddress, double[]> denseSubspaceMap; private final Index.Builder indexBuilder; private final Index index; private final TensorType denseSubtype; - private BoundBuilder(TensorType type) { + private BoundBuilder(TensorType type, int expectedSize) { super(type); + denseSubspaceMap = new LinkedHashMap<>(expectedSize, 0.5f); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); denseSubtype = new TensorType(type.valueType(), @@ -324,10 +309,7 @@ public class MixedTensor implements Tensor { } private double[] denseSubspace(TensorAddress sparseAddress) { - if (!denseSubspaceMap.containsKey(sparseAddress)) { - denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); - } - return denseSubspaceMap.get(sparseAddress); + return denseSubspaceMap.computeIfAbsent(sparseAddress, (key) -> new double[(int)denseSubspaceSize()]); } public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { @@ -343,7 +325,7 @@ public class MixedTensor implements Tensor { @Override public Tensor.Builder cell(TensorAddress address, double value) { - TensorAddress sparsePart = index.sparsePartialAddress(address); + TensorAddress sparsePart = address.mappedPartialAddress(index.sparseType, index.type.dimensions()); int denseOffset = index.denseOffsetOf(address); double[] denseSubspace = denseSubspace(sparsePart); denseSubspace[denseOffset] = value; @@ -362,19 +344,20 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseSubspace> list = new ArrayList<>(); - for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { + //TODO This can be solved more efficiently with a single map. 
+ Set<Map.Entry<TensorAddress, double[]>> entrySet = denseSubspaceMap.entrySet(); + for (Map.Entry<TensorAddress, double[]> entry : entrySet) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); var block = new DenseSubspace(sparsePart, denseSubspace); - indexBuilder.addBlock(sparsePart, list.size()); - list.add(block); + indexBuilder.addBlock(block); } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } public static BoundBuilder of(TensorType type) { - return new BoundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -391,9 +374,9 @@ public class MixedTensor implements Tensor { private final Map<TensorAddress, Double> cells; private final long[] dimensionBounds; - private UnboundBuilder(TensorType type) { + private UnboundBuilder(TensorType type, int expectedSize) { super(type); - cells = new HashMap<>(); + cells = new LinkedHashMap<>(expectedSize, 0.5f); dimensionBounds = new long[type.dimensions().size()]; } @@ -412,7 +395,7 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { TensorType boundType = createBoundType(); - BoundBuilder builder = new BoundBuilder(boundType); + BoundBuilder builder = new BoundBuilder(boundType, cells.size()); for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { builder.cell(cell.getKey(), cell.getValue()); } @@ -443,7 +426,8 @@ public class MixedTensor implements Tensor { } public static UnboundBuilder of(TensorType type) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -460,8 +444,10 @@ public class MixedTensor implements Tensor { private final TensorType denseType; private final List<TensorType.Dimension> mappedDimensions; private final List<TensorType.Dimension> indexedDimensions; + private final int[] indexedDimensionsSize; private ImmutableMap<TensorAddress, Integer> sparseMap; + private List<DenseSubspace> denseSubspaces; private final int denseSubspaceSize; static private int computeDSS(List<TensorType.Dimension> dimensions) { @@ -477,17 +463,31 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).toList(); this.indexedDimensions = type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList(); + this.indexedDimensionsSize = new int[indexedDimensions.size()]; + for (int i = 0; i < indexedDimensions.size(); i++) { + long dimensionSize = indexedDimensions.get(i).size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")); + indexedDimensionsSize[i] = (int)dimensionSize; + } + this.sparseType = createPartialType(type.valueType(), mappedDimensions); this.denseType = createPartialType(type.valueType(), indexedDimensions); this.denseSubspaceSize = computeDSS(this.indexedDimensions); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); + } } - int blockIndexOf(TensorAddress address) { - TensorAddress sparsePart = sparsePartialAddress(address); - return sparseMap.getOrDefault(sparsePart, -1); + private DenseSubspace blockOf(TensorAddress address) { + TensorAddress sparsePart = address.mappedPartialAddress(sparseType, type.dimensions()); + Integer blockNum = sparseMap.get(sparsePart); + if (blockNum == null || 
blockNum >= denseSubspaces.size()) { + return null; + } + return denseSubspaces.get(blockNum); } - int denseOffsetOf(TensorAddress address) { + private int denseOffsetOf(TensorAddress address) { long innerSize = 1; long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { @@ -506,54 +506,19 @@ public class MixedTensor implements Tensor { return denseSubspaceSize; } - private TensorAddress sparsePartialAddress(TensorAddress address) { - if (type.dimensions().size() != address.size()) - throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address); - TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); - for (int i = 0; i < type.dimensions().size(); ++i) { - TensorType.Dimension dimension = type.dimensions().get(i); - if ( ! dimension.isIndexed()) - builder.add(dimension.name(), address.label(i)); - } - return builder.build(); - } - - private TensorAddress denseOffsetToAddress(long denseOffset) { + private void denseOffsetToAddress(long denseOffset, int [] labels) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } long restSize = denseOffset; long innerSize = denseSubspaceSize; - long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { - TensorType.Dimension dimension = indexedDimensions.get(i); - long dimensionSize = dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); - - innerSize /= dimensionSize; - labels[i] = restSize / innerSize; + innerSize /= indexedDimensionsSize[i]; + labels[i] = (int) (restSize / innerSize); restSize %= innerSize; } - return TensorAddress.of(labels); - } - - TensorAddress fullAddressOf(TensorAddress sparsePart, long denseOffset) { - TensorAddress densePart = denseOffsetToAddress(denseOffset); - String[] labels = new String[type.dimensions().size()]; - int mappedIndex = 0; - int indexedIndex = 0; - for (TensorType.Dimension d : type.dimensions()) { - if (d.isIndexed()) { - labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex); - indexedIndex++; - } else { - labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex); - mappedIndex++; - } - } - return TensorAddress.of(labels); } @Override @@ -563,7 +528,7 @@ public class MixedTensor implements Tensor { private String contentToString(MixedTensor tensor, long maxCells) { if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller"); - if (mappedDimensions.size() == 0) { + if (mappedDimensions.isEmpty()) { StringBuilder b = new StringBuilder(); int cellsWritten = denseSubspaceToString(tensor, 0, maxCells, b); if (cellsWritten == maxCells && cellsWritten < tensor.size()) @@ -605,8 +570,7 @@ public class MixedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (type.valueType()) { @@ -619,32 +583,38 @@ public class MixedTensor implements Tensor { } // end bracket - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } return index; } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; + return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } - static class Builder { 
+ private static class Builder { private final Index index; - private final ImmutableMap.Builder<TensorAddress, Integer> builder; + private final ImmutableMap.Builder<TensorAddress, Integer> builder = new ImmutableMap.Builder<>(); + private final ImmutableList.Builder<DenseSubspace> listBuilder = new ImmutableList.Builder<>(); + private int count = 0; Builder(TensorType type) { index = new Index(type); - builder = new ImmutableMap.Builder<>(); } - void addBlock(TensorAddress address, int sz) { - builder.put(address, sz); + void addBlock(DenseSubspace block) { + if (block.cells.length != index.denseSubspaceSize) { + throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize + + " cells, but got: " + block.cells.length); + } + builder.put(block.sparseAddress, count++); + listBuilder.add(block); } Index build() { index.sparseMap = builder.build(); + index.denseSubspaces = listBuilder.build(); return index; } @@ -654,27 +624,16 @@ public class MixedTensor implements Tensor { } } - private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { - - private final TensorType type; - private final double[] values; - - public DenseSubspaceBuilder(TensorType type, double[] values) { - this.type = type; - this.values = values; - } - - @Override - public TensorType type() { return type; } + private record DenseSubspaceBuilder(TensorType type, double[] values) implements IndexedTensor.DirectIndexBuilder { @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; + values[(int) index] = value; } @Override public void cellByDirectIndex(long index, float value) { - values[(int)index] = value; + values[(int) index] = value; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index f1b3245ec80..8852bcd1ff3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,16 +1,16 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import java.util.Arrays; +import com.yahoo.tensor.impl.Label; /** - * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors + * An address to a subset of a tensors' cells, specifying a label for some, but not necessarily all, of the tensors * dimensions. * * @author bratseth */ // Implementation notes: -// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. +// - These are created in inner (though not innermost) loops, so they are implemented with minimal allocation. // We also avoid non-essential error checking. // - We can add support for string labels later without breaking the API public class PartialAddress { @@ -18,7 +18,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. 
private final String[] dimensionNames; - private final Object[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -35,15 +35,15 @@ public class PartialAddress { public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return asLong(labels[i]); - return -1; + return labels[i]; + return Tensor.invalidIndex; } /** Returns the label of this dimension, or null if no label is specified for it */ public String label(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i].toString(); + return Label.fromNumber(labels[i]); return null; } @@ -55,7 +55,7 @@ public class PartialAddress { public String label(int i) { if (i >= size()) throw new IllegalArgumentException("No label at position " + i + " in " + this); - return labels[i].toString(); + return Label.fromNumber(labels[i]); } public int size() { return dimensionNames.length; } @@ -65,40 +65,14 @@ public class PartialAddress { public TensorAddress asAddress(TensorType type) { if (type.rank() != size()) throw new IllegalArgumentException(type + " has a different rank than " + this); - if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { - long[] numericLabels = new long[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - long label = numericLabel(type.dimensions().get(i).name()); - if (label < 0) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - numericLabels[i] = label; - } - return TensorAddress.of(numericLabels); - } - else { - String[] stringLabels = new String[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - String label = label(type.dimensions().get(i).name()); - if (label == null) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - stringLabels[i] = label; - } - return TensorAddress.of(stringLabels); - } - } - - private long asLong(Object label) { - if (label instanceof Long) { - return (Long) label; - } - else { - try { - return Long.parseLong(label.toString()); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Label '" + label + "' is not numeric"); - } + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label == Tensor.invalidIndex) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; } + return TensorAddress.of(numericLabels); } @Override @@ -114,12 +88,12 @@ public class PartialAddress { public static class Builder { private String[] dimensionNames; - private Object[] labels; + private long[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new Object[size]; + labels = new long[size]; } public Builder add(String dimensionName, long label) { @@ -131,7 +105,7 @@ public class PartialAddress { public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; - labels[index] = label; + labels[index] = Label.toNumber(label); index++; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8a4179cdc1a..ac9dc4e4eca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ 
b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; import com.yahoo.tensor.functions.Expand; +import com.yahoo.tensor.impl.Label; import java.util.ArrayList; import java.util.Arrays; @@ -39,7 +40,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; * A multidimensional array which can be used in computations. * <p> * A tensor consists of a set of <i>dimension</i> names and a set of <i>cells</i> containing scalar <i>values</i>. - * Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines + * Each cell is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines * the location of that cell. Both dimensions and labels are string on the form of an identifier or integer. * <p> * The size of the set of dimensions of a tensor is called its <i>rank</i>. @@ -55,6 +56,9 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; */ public interface Tensor { + /** The constant signaling a nonexisting value in operations addressing tensor values by index. */ + int invalidIndex = -1; + // ----------------- Accessors TensorType type(); @@ -63,11 +67,25 @@ public interface Tensor { default boolean isEmpty() { return size() == 0; } /** - * Returns the number of cells in this. - * TODO Figure how to best return an int instead of a long - * An int is large enough, and java is far better at int base loops than long - **/ - long size(); + * Returns the number of cells in this, allowing for very large tensors. + * Prefer sizeAsInt in implementations that cannot handle sizes outside the int range. + */ + default long size() { + return sizeAsInt(); + } + + /** + * Returns the size of this as an int or throws an exception if it is too large to fit in an int. + * Prefer this over size() with implementations that only handle sizes in the int range. + * + * @throws IndexOutOfBoundsException if the size is too large to fit in an int + */ + default int sizeAsInt() { + long size = size(); + if (size > Integer.MAX_VALUE) + throw new IndexOutOfBoundsException("size = " + size + ", which is too large to fit in an int"); + return (int) size; + } /** Returns the value of a cell, or 0.0 if this cell does not exist */ double get(TensorAddress address); @@ -75,6 +93,9 @@ public interface Tensor { /** Returns true if this cell exists */ boolean has(TensorAddress address); + /** Returns the value at this address, or null of it does not exist. */ + Double getAsDouble(TensorAddress address); + /** * Returns the cell of this in some undefined order. * A cell instances is only valid until next() is called. @@ -97,7 +118,7 @@ public interface Tensor { * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -113,7 +134,7 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been * modified according to the given operation and cells in the given map. - * Cells in the map outside of existing cells are thus ignored. + * Cells in the map outside existing cells are thus ignored. 
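The size()/sizeAsInt() split and the nullable getAsDouble accessor added above can be illustrated with a short sketch; the mapped-tensor literal is only an example:

    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.TensorAddress;

    Tensor tensor = Tensor.from("tensor(x{}):{{x:a}:1.0}");

    int cells = tensor.sizeAsInt();                                   // 1; throws if the size exceeds the int range
    Double present = tensor.getAsDouble(TensorAddress.ofLabels("a")); // 1.0
    Double missing = tensor.getAsDouble(TensorAddress.ofLabels("b")); // null, whereas get(...) would return 0.0
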
* * @param op the modifying function * @param cells the cells to modify @@ -132,9 +153,9 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been - * removed according to the given set of addresses. Only valid for sparse + * removed according to the given set of addresses. Only valid for mapped * or mixed tensors. For mixed tensors, addresses are assumed to only - * contain the sparse dimensions, as the entire dense subspace is removed. + * contain the mapped dimensions, as the entire indexed subspace is removed. * * @param addresses list of addresses to remove * @return a new tensor where cells have been removed @@ -484,11 +505,10 @@ public interface Tensor { public TensorAddress getKey() { return address; } /** - * Returns the direct index which can be used to locate this cell, or -1 if not available. - * This is for optimizations mapping between tensors where this is possible without creating a - * TensorAddress. + * Returns the direct index which can be used to locate this cell, or Tensor.invalidIndex if not available. + * This is for optimizations mapping between tensors where this is possible without creating a TensorAddress. */ - long getDirectIndex() { return -1; } + long getDirectIndex() { return invalidIndex; } /** Returns the value as a double */ @Override @@ -537,8 +557,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -549,8 +569,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -608,7 +628,7 @@ public interface Tensor { public TensorType type() { return tensorBuilder.type(); } public CellBuilder label(String dimension, long label) { - return label(dimension, String.valueOf(label)); + return label(dimension, Label.fromNumber(label)); } public Builder value(double cellValue) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index a1cb278c75a..4fa759668b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.Label; +import com.yahoo.tensor.impl.TensorAddressAny; + import java.util.Arrays; +import java.util.List; import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; /** * An immutable address to a tensor cell. 
This simply supplies a value to each dimension @@ -14,18 +17,20 @@ import java.util.stream.Collectors; */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - public static TensorAddress of(String[] labels) { - return new StringTensorAddress(labels); + return TensorAddressAny.of(labels); + } + + public static TensorAddress ofLabels(String... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress ofLabels(String ... labels) { - return new StringTensorAddress(labels); + public static TensorAddress of(long... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress of(long ... labels) { - return new NumericTensorAddress(labels); + public static TensorAddress of(int... labels) { + return TensorAddressAny.of(labels); } /** Returns the number of labels in this */ @@ -61,27 +66,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { } @Override - public int hashCode() { - int result = 1; - for (int i = 0; i < size(); i++) { - if (label(i) != null) - result = 31 * result + label(i).hashCode(); + public String toString() { + StringBuilder sb = new StringBuilder("cell address ("); + int size = size(); + if (size > 0) { + sb.append(label(0)); + for (int i = 1; i < size; i++) { + sb.append(',').append(label(i)); + } } - return result; - } - @Override - public boolean equals(Object o) { - if (o == this) return true; - if ( ! (o instanceof TensorAddress other)) return false; - if (other.size() != this.size()) return false; - for (int i = 0; i < this.size(); i++) - if ( ! Objects.equals(this.label(i), other.label(i))) - return false; - return true; + return sb.append(')').toString(); } - /** Returns this as a string on the appropriate form given the type */ + /** + * Returns this as a string on the appropriate form given the type + */ public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { @@ -94,106 +94,78 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - /** Returns a label as a string with appropriate quoting/escaping when necessary */ + /** + * Returns a label as a string with appropriate quoting/escaping when necessary + */ public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; return "'" + label + "'"; } - private static String[] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); + /** Returns an address with only some of the dimension. Ordering will also be according to indexMap */ + public TensorAddress partialCopy(int[] indexMap) { + int[] labels = new int[indexMap.length]; + for (int i = 0; i < labels.length; ++i) { + labels[i] = (int)numericLabel(indexMap[i]); } - return asStrings; + return TensorAddressAny.ofUnsafe(labels); } - private static String asString(long index) { - return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } - - private static final class StringTensorAddress extends TensorAddress { - - private final String[] labels; - - private StringTensorAddress(String ... 
labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return labels[i]; } - - @Override - public long numericLabel(int i) { - try { - return Long.parseLong(labels[i]); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); + /** Creates a complete address by taking the mapped dimmensions from this and the indexed from the indexedPart */ + public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int[] densePart) { + int[] labels = new int[dimensions.size()]; + int mappedIndex = 0; + int indexedIndex = 0; + for (int i = 0; i < labels.length; i++) { + TensorType.Dimension d = dimensions.get(i); + if (d.isIndexed()) { + labels[i] = densePart[indexedIndex]; + indexedIndex++; + } else { + labels[i] = (int)numericLabel(mappedIndex); + mappedIndex++; } } - - @Override - public TensorAddress withLabel(int index, long label) { - String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = TensorAddress.asString(label); - return new StringTensorAddress(labels); - } - - - @Override - public String toString() { - return "cell address (" + String.join(",", labels) + ")"; - } - + return TensorAddressAny.ofUnsafe(labels); } - private static final class NumericTensorAddress extends TensorAddress { - - private final long[] labels; - - private NumericTensorAddress(long[] labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return TensorAddress.asString(labels[i]); } - - @Override - public long numericLabel(int i) { return labels[i]; } - - @Override - public TensorAddress withLabel(int index, long label) { - long[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = label; - return new NumericTensorAddress(labels); - } - - @Override - public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")"; + /** + * Returns an address containing the mapped dimensions of this. + * + * @param mappedType the type of the mapped subset of the type this is an address of; + * which is also the type of the returned address + * @param dimensions all the dimensions of the type this is an address of + */ + public TensorAddress mappedPartialAddress(TensorType mappedType, List<TensorType.Dimension> dimensions) { + if (dimensions.size() != size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this); + TensorAddress.Builder builder = new TensorAddress.Builder(mappedType); + for (int i = 0; i < dimensions.size(); ++i) { + TensorType.Dimension dimension = dimensions.get(i); + if ( ! 
dimension.isIndexed()) + builder.add(dimension.name(), (int)numericLabel(i)); } - + return builder.build(); } /** Builder of a tensor address */ public static class Builder { final TensorType type; - final String[] labels; + final int[] labels; + + private static int[] createEmptyLabels(int size) { + int[] labels = new int[size]; + Arrays.fill(labels, Tensor.invalidIndex); + return labels; + } public Builder(TensorType type) { - this(type, new String[type.dimensions().size()]); + this(type, createEmptyLabels(type.dimensions().size())); } - private Builder(TensorType type, String[] labels) { + private Builder(TensorType type, int[] labels) { this.type = type; this.labels = labels; } @@ -207,7 +179,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { var mappedSubtype = type.mappedSubtype(); if (mappedSubtype.rank() != 1) throw new IllegalArgumentException("Cannot add a label without explicit dimension to a tensor of type " + - type + ": Must have exactly one sparse dimension"); + type + ": Must have exactly one mapped dimension"); add(mappedSubtype.dimensions().get(0).name(), label); return this; } @@ -220,10 +192,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public Builder add(String dimension, String label) { Objects.requireNonNull(dimension, "dimension cannot be null"); Objects.requireNonNull(label, "label cannot be null"); - Optional<Integer> labelIndex = type.indexOfDimension(dimension); - if ( labelIndex.isEmpty()) + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) + throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); + labels[labelIndex] = Label.toNumber(label); + return this; + } + + public Builder add(String dimension, long label) { + return add(dimension, Convert.safe2Int(label)); + } + public Builder add(String dimension, int label) { + Objects.requireNonNull(dimension, "dimension cannot be null"); + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); - labels[labelIndex.get()] = label; + labels[labelIndex] = label; return this; } @@ -237,14 +221,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { void validate() { for (int i = 0; i < labels.length; i++) - if (labels[i] == null) + if (labels[i] == Tensor.invalidIndex) throw new IllegalArgumentException("Missing a label for dimension '" + type.dimensions().get(i).name() + "' for " + type); } public TensorAddress build() { validate(); - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } } @@ -256,7 +240,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { super(type); } - private PartialBuilder(TensorType type, String[] labels) { + private PartialBuilder(TensorType type, int[] labels) { super(type, labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..6b81d023a9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
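The TensorAddress changes above replace the String- and long-backed address classes with int-encoded labels provided by TensorAddressAny. A minimal usage sketch of the refactored factories and Builder, assuming the usual TensorType.Builder API; the class and variable names below are illustrative only, not part of the commit:

import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

public class TensorAddressSketch {
    public static void main(String[] args) {
        // Numeric labels for indexed dimensions; both factory paths now delegate to TensorAddressAny.
        TensorAddress numeric = TensorAddress.of(0, 2, 7);

        // String labels for mapped dimensions; converted to ints internally via Label.toNumber.
        TensorAddress mapped = TensorAddress.ofLabels("cat", "dog");

        // Builder path: labels are now kept in an int[] initialized to Tensor.invalidIndex,
        // and add(dimension, label) stores Label.toNumber(label) at the dimension's position.
        TensorType type = new TensorType.Builder().mapped("x").mapped("y").build();
        TensorAddress built = new TensorAddress.Builder(type)
                .add("x", "cat")
                .add("y", "dog")
                .build();

        // The two construction paths should agree on equality.
        System.out.println(mapped.equals(built));
    }
}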
package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List<Dimension> dimensions; + private final Set<String> dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection<Dimension> dimensions) { @@ -103,12 +108,25 @@ public class TensorType { List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ public class TensorType { } } + public boolean hasIndexedDimensions() { return indexedSubtype != empty; } + public boolean hasMappedDimensions() { return mappedSubtype != empty; } + public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... 
types) { List<Value> valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ public class TensorType { /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ @@ -176,6 +199,13 @@ public class TensorType { return Optional.of(i); return Optional.empty(); } + /** Returns the 0-base index of this dimension, or empty if it is not present */ + public int indexOfDimensionAsInt(String dimension) { + for (int i = 0; i < dimensions.size(); i++) + if (dimensions.get(i).name().equals(dimension)) + return i; + return Tensor.invalidIndex; + } /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ public Optional<Long> sizeOfDimension(String dimension) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 0e4fab95c87..9125b35ea5d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.Arrays; import java.util.HashMap; @@ -133,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return tensor; } else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + if (tensor.type().hasMappedDimensions()) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) @@ -172,7 +173,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType concatType, long concatOffset, String concatDimension) { long[] combinedLabels = new long[concatType.dimensions().size()]; - Arrays.fill(combinedLabels, -1); + Arrays.fill(combinedLabels, Tensor.invalidIndex); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... 
which is overwritten by the right value here @@ -191,7 +192,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.invalidIndex); return toIndexes; } @@ -208,7 +209,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != from.numericLabel(i)) return false; to[toIndex] = from.numericLabel(i); } } @@ -354,21 +355,21 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { - String[] labels = new String[plan.resultType.rank()]; + int[] labels = new int[plan.resultType.rank()]; int out = 0; int m = 0; int a = 0; int b = 0; for (var how : plan.combineHow) { switch (how) { - case left -> labels[out++] = leftOnly.label(a++); - case right -> labels[out++] = rightOnly.label(b++); - case both -> labels[out++] = match.label(m++); - case concat -> labels[out++] = String.valueOf(concatDimIdx); + case left -> labels[out++] = (int) leftOnly.numericLabel(a++); + case right -> labels[out++] = (int) rightOnly.numericLabel(b++); + case both -> labels[out++] = (int) match.numericLabel(m++); + case concat -> labels[out++] = concatDimIdx; default -> throw new IllegalArgumentException("cannot handle: " + how); } } - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } Tensor merge(CellVectorMapMap a, CellVectorMapMap b) { @@ -398,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET CellVectorMapMap decompose(Tensor input, SplitHow how) { var iter = input.cellIterator(); - String[] commonLabels = new String[(int)how.numCommon()]; - String[] separateLabels = new String[(int)how.numSeparate()]; + int[] commonLabels = new int[(int)how.numCommon()]; + int[] separateLabels = new int[(int)how.numSeparate()]; CellVectorMapMap result = new CellVectorMapMap(); while (iter.hasNext()) { var cell = iter.next(); @@ -409,14 +410,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case common -> commonLabels[commonIdx++] = addr.label(i); - case separate -> separateLabels[separateIdx++] = addr.label(i); + case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i); + case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i); case concat -> ccDimIndex = addr.numericLabel(i); default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i)); } } - TensorAddress commonAddr = TensorAddress.of(commonLabels); - TensorAddress separateAddr = TensorAddress.of(separateLabels); + TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels); + TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels); result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); } return result; diff 
--git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 3b6e03186a3..b595b1a40cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 0) + if (!arguments.isEmpty()) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells.values()) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } @@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); - if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + if ( ! type.hasOnlyIndexedBoundDimensions()) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + "only indexed, bound dimensions, but this has " + type); this.cells = List.copyOf(cells); @@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 4c92e1e57a2..fb345264f56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -12,8 +12,11 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -113,7 +116,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { - long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); @@ -128,8 +131,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> 
aCell = i.next(); var key = aCell.getKey(); - if (b.has(key)) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + Double bVal = b.getAsDouble(key); + if (bVal != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); @@ -144,7 +148,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { - if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes + if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); @@ -169,7 +173,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { - long joinedLength = Math.min(subspaceSize, superspaceSize); + int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -204,12 +208,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); - TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - if (subspace.has(subaddress)) { - double subspaceValue = subspace.get(subaddress); + TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes); + Double subspaceValue = subspace.getAsDouble(subaddress); + if (subspaceValue != null) { builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder + ? 
combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); @@ -223,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return subspaceIndexes; } - private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { - String[] subspaceLabels = new String[subspaceIndexes.length]; - for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); - return TensorAddress.of(subspaceLabels); - } - /** Slow join which works for any two tensors */ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) @@ -250,8 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { - Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); - Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames())); + int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection + Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames())); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); @@ -262,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + PartialAddress matchingBCells = sharedDimensionSize > 0 + ? 
partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize) + : empty; // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -274,11 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } - private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { - PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); - for (int i = 0; i < addressType.dimensions().size(); i++) - if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); + private static final PartialAddress empty = new PartialAddress.Builder(0).build(); + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, + Set<String> retainDimensions, int sharedDimensionSize) { + PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); + for (int i = 0; i < addressType.dimensions().size(); i++) { + String dimension = addressType.dimensions().get(i).name(); + if (retainDimensions.contains(dimension)) + builder.add(dimension, address.numericLabel(i)); + } return builder.build(); } @@ -330,19 +335,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions - Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(); + Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); - aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>()); - aCellsByCommonAddress.get(partialCommonAddress).add(aCell); + TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon); + aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } // Iterate once through the larger tensor and use the hash map to find joinable cells Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell bCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon); + TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) { TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType); @@ -358,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } /** - * Returns the an array having one entry in order for each dimension of fromType + * Returns an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. 
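The mapIndexes contract documented here can be made concrete with a standalone sketch that uses plain lists instead of the actual TensorType API (the helper below is illustrative, not part of the commit): entry i of the result holds the position of fromType's i-th dimension in toType, or Tensor.invalidIndex (-1) when that dimension is absent.

import java.util.Arrays;
import java.util.List;

public class MapIndexesSketch {
    static int[] mapIndexes(List<String> fromDims, List<String> toDims) {
        int[] toIndexes = new int[fromDims.size()];
        for (int i = 0; i < fromDims.size(); i++)
            toIndexes[i] = toDims.indexOf(fromDims.get(i)); // List.indexOf returns -1 when absent
        return toIndexes;
    }

    public static void main(String[] args) {
        // from: [x, y, z], to: [y, z] -> x is absent, y maps to 0, z maps to 1
        System.out.println(Arrays.toString(mapIndexes(List.of("x", "y", "z"), List.of("y", "z"))));
        // prints [-1, 0, 1]
    }
}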
* That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) @@ -367,17 +371,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name()); return toIndexes; } private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { - String[] joinedLabels = new String[joinedType.dimensions().size()]; + int[] joinedLabels = new int[joinedType.dimensions().size()]; + Arrays.fill(joinedLabels, Tensor.invalidIndex); mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; - return TensorAddress.of(joinedLabels); + return TensorAddressAny.ofUnsafe(joinedLabels); } /** @@ -386,11 +391,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { - for (int i = 0; i < from.size(); i++) { + private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) { + for (int i = 0, size = from.size(); i < size; i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; - to[toIndex] = from.label(i); + int label = Convert.safe2Int(from.numericLabel(i)); + if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != label) + return false; + to[toIndex] = label; } return true; } @@ -412,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return typeBuilder.build(); } - private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { - TensorAddress address = cell.getKey(); - String[] labels = new String[indexMap.length]; - for (int i = 0; i < labels.length; ++i) { - labels[i] = address.label(indexMap[i]); - } - return TensorAddress.of(labels); - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java index c87ef42976d..aa9602339e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java @@ -98,9 +98,9 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction for (int i = 0; i < inputType.dimensions().size(); i++) { var dim = inputType.dimensions().get(i); if (dim.isMapped()) { - mapAddrBuilder.add(dim.name(), fullAddr.label(i)); + mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } else { - idxAddrBuilder.add(dim.name(), fullAddr.label(i)); + idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } } var mapAddr = mapAddrBuilder.build(); @@ -123,11 +123,11 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction var addrBuilder = new TensorAddress.Builder(outputType); for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) { var dim = inputTypeMapped.dimensions().get(i); - 
addrBuilder.add(dim.name(), mappedAddr.label(i)); + addrBuilder.add(dim.name(), mappedAddr.numericLabel(i)); } for (int i = 0; i < denseOutputDims.size(); i++) { var dim = denseOutputDims.get(i); - addrBuilder.add(dim.name(), denseAddr.label(i)); + addrBuilder.add(dim.name(), denseAddr.numericLabel(i)); } builder.cell(addrBuilder.build(), cell.getValue()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 59394785382..ddad91dc060 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (! b.has(key)) { + Double bVal = b.getAsDouble(key); + if (bVal == null) { builder.cell(key, aCell.getValue()); } else if (combinator != null) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 8cf88610599..947fd6e0012 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -9,16 +11,15 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; /** * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions @@ -112,32 +113,84 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) { - if ( ! dimensions.isEmpty() && ! 
argument.type().dimensionNames().containsAll(dimensions)) + if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + - dimensions + ": Not all those dimensions are present in this tensor"); + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all - if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) + if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) { if (argument.isEmpty()) return Tensor.from(0.0); else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) - return reduceIndexedVector((IndexedTensor)argument, aggregator); + return reduceIndexedVector((IndexedTensor) argument, aggregator); else return reduceAllGeneral(argument, aggregator); + } TensorType reducedType = outputType(argument.type(), dimensions); + int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions); + int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce); + if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) { + return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator); + } else { + return reduceGeneral(argument, reducedType, indexesToKeep, aggregator); + } + } + + private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) { + int currentIndex = reduce[reduceIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + if (reduceIndex + 1 < reduce.length) { + int nextDimension = reduceIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + reduce(argument, aggregator, address, reduce, nextDimension); + } + } else { + address.setIndex(currentIndex, 0); + long increment = address.getStride(currentIndex); + long directIndex = address.getDirectIndex(); + for (int i = 0; i < dimSize; i++) { + aggregator.aggregate(argument.get(directIndex + i * increment)); + } + } + } + + private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) { + if (keepIndex < toKeep.length) { + int currentIndex = toKeep[keepIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + + int nextKeep = keepIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + destAddress.setIndex(keepIndex, i); + reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce); + } + } else { + ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); + reduce(argument, valueAggregator, address, toReduce, 0); + builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes()); + } + + } - // Reduce cells - int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); + private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) { + + var reducedBuilder = IndexedTensor.Builder.of(reducedType); + DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType)); + reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), 
indexesToKeep, 0, indexesToReduce); + return reducedBuilder.build(); + } + + private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) { // TODO cells.size() is most likely an overestimate, and might need a better heuristic // But the upside is larger than the downside. - Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size()); + Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); - ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - if (aggr == null) - aggr = aggregatingCells.get(reducedAddress); + TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep); + ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); @@ -146,39 +199,43 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return reducedBuilder.build(); } - private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) { - Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2); - for (String dimensionToRemove : dimensions) - indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); - int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()]; + + private static int[] createIndexesToReduce(TensorType tensorType, List<String> dimensions) { + int[] indexesToReduce = new int[dimensions.size()]; + for (int i = 0; i < dimensions.size(); i++) { + indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get(); + } + return indexesToReduce; + } + private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) { + int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length]; int toKeepIndex = 0; for (int i = 0; i < argumentType.rank(); i++) { - if ( ! indexesToRemove.contains(i)) + if ( ! 
contains(indexesToReduce, i)) indexesToKeep[toKeepIndex++] = i; } return indexesToKeep; } - - private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) { - String[] reducedLabels = new String[indexesToKeep.length]; - int reducedLabelIndex = 0; - for (int toKeep : indexesToKeep) - reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return TensorAddress.of(reducedLabels); + private static boolean contains(int[] list, int key) { + for (int candidate : list) { + if (candidate == key) return true; + } + return false; } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.dimensionSizes().size(0); i++) + int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0)); + for (int i = 0; i < dimensionSize ; i++) valueAggregator.aggregate(argument.get(i)); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } static abstract class ValueAggregator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index aece782d296..2d5a0518747 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if ( ! (a instanceof IndexedTensor)) return false; - if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (a.type().hasOnlyIndexedBoundDimensions())) return false; if ( ! (b instanceof IndexedTensor)) return false; - if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! 
(b.type().hasOnlyIndexedBoundDimensions())) return false; TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index a2a3874eced..05db61f5395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -35,7 +35,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); - if (fromDimensions.size() < 1) + if (fromDimensions.isEmpty()) throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); if (fromDimensions.size() != toDimensions.size()) throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + @@ -89,7 +89,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET for (int i = 0; i < tensor.type().dimensions().size(); i++) { String dimensionName = tensor.type().dimensions().get(i).name(); String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); - toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); + toIndexes[renamedType.indexOfDimension(newDimensionName).get()] = i; } // avoid building a new tensor if dimensions can simply be renamed @@ -100,7 +100,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress renamedAddress = rename(cell.getKey(), toIndexes); + TensorAddress renamedAddress = cell.getKey().partialCopy(toIndexes); builder.cell(renamedAddress, cell.getValue()); } return builder.build(); @@ -118,13 +118,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return true; } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { - String[] reorderedLabels = new String[toIndexes.length]; - for (int i = 0; i < toIndexes.length; i++) - reorderedLabels[toIndexes[i]] = address.label(i); - return TensorAddress.of(reorderedLabels); - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 807f56b1a49..38ac42a5f1f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -131,7 +131,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (int i = 0; i < address.size(); i++) { String dimension = type.dimensions().get(i).name(); if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) - b.add(dimension, address.label(i)); + b.add(dimension, (int)address.numericLabel(i)); } return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java new file mode 100644 index 00000000000..e2cb64fdd1f --- /dev/null +++ 
b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java @@ -0,0 +1,16 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +/** + * Utility to make common conversions safe + * + * @author baldersheim + */ +public class Convert { + public static int safe2Int(long value) { + if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException("value = " + value + ", which is too large to fit in an int"); + } + return (int) value; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java new file mode 100644 index 00000000000..7c1e8646245 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java @@ -0,0 +1,83 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.Tensor; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A label is a value of a mapped dimension of a tensor. + * This class provides a mapping of labels to numbers which allow for more efficient computation with + * mapped tensor dimensions. + * + * @author baldersheim + */ +public class Label { + + private static final String[] SMALL_INDEXES = createSmallIndexesAsStrings(1000); + + private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>(); + + // Index 0 is unused, that is a valid positive number + // 1(-1) is reserved for the Tensor.INVALID_INDEX + private static volatile String[] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"}; + private static int numUniqeStrings = 2; + + private static String[] createSmallIndexesAsStrings(int count) { + String[] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private static int addNewUniqueString(String s) { + synchronized (string2Enum) { + if (numUniqeStrings >= uniqueStrings.length) { + uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2); + } + uniqueStrings[numUniqeStrings] = s; + return -numUniqeStrings++; + } + } + + private static String asNumericString(long index) { + return ((index >= 0) && (index < SMALL_INDEXES.length)) ? 
SMALL_INDEXES[(int)index] : String.valueOf(index); + } + + private static boolean validNumericIndex(String s) { + if (s.isEmpty() || ((s.length() > 1) && (s.charAt(0) == '0'))) return false; + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if ((c < '0') || (c > '9')) return false; + } + return true; + } + + public static int toNumber(String s) { + if (s == null) { return Tensor.invalidIndex; } + try { + if (validNumericIndex(s)) { + return Integer.parseInt(s); + } + } catch (NumberFormatException e) { + } + return string2Enum.computeIfAbsent(s, Label::addNewUniqueString); + } + + public static String fromNumber(int v) { + if (v >= 0) { + return asNumericString(v); + } else { + if (v == Tensor.invalidIndex) { return null; } + return uniqueStrings[-v]; + } + } + + public static String fromNumber(long v) { + return fromNumber(Convert.safe2Int(v)); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java new file mode 100644 index 00000000000..2e70811a67c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java @@ -0,0 +1,154 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import static com.yahoo.tensor.impl.Convert.safe2Int; +import static com.yahoo.tensor.impl.Label.toNumber; +import static com.yahoo.tensor.impl.Label.fromNumber; + +/** + * Parent of tensor address family centered around each dimension as int. + * A positive number represents a numeric index usable as a direect addressing. + * - 1 is representing an invalid/null address + * Other negative numbers are an enumeration maintained in {@link Label} + * + * @author baldersheim + */ +abstract public class TensorAddressAny extends TensorAddress { + + @Override + public String label(int i) { + return fromNumber((int)numericLabel(i)); + } + + public static TensorAddress of() { + return TensorAddressEmpty.empty; + } + + public static TensorAddress of(String label) { + return new TensorAddressAny1(toNumber(label)); + } + + public static TensorAddress of(String label0, String label1) { + return new TensorAddressAny2(toNumber(label0), toNumber(label1)); + } + + public static TensorAddress of(String label0, String label1, String label2) { + return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2)); + } + + public static TensorAddress of(String label0, String label1, String label2, String label3) { + return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3)); + } + + public static TensorAddress of(String[] labels) { + int[] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = toNumber(labels[i]); + } + return ofUnsafe(labelsAsInt); + } + + public static TensorAddress of(int label) { + return new TensorAddressAny1(sanitize(label)); + } + + public static TensorAddress of(int label0, int label1) { + return new TensorAddressAny2(sanitize(label0), sanitize(label1)); + } + + public static TensorAddress of(int label0, int label1, int label2) { + return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2)); + } + + public static TensorAddress of(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), 
sanitize(label3)); + } + + public static TensorAddress of(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> new TensorAddressAny1(sanitize(labels[0])); + case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1])); + case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2])); + case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3])); + default -> { + for (int i = 0; i < labels.length; i++) { + sanitize(labels[i]); + } + yield new TensorAddressAnyN(labels); + } + }; + } + + public static TensorAddress of(long label) { + return of(safe2Int(label)); + } + + public static TensorAddress of(long label0, long label1) { + return of(safe2Int(label0), safe2Int(label1)); + } + + public static TensorAddress of(long label0, long label1, long label2) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2)); + } + + public static TensorAddress of(long label0, long label1, long label2, long label3) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3)); + } + + public static TensorAddress of(long ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(safe2Int(labels[0])); + case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1])); + case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2])); + case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3])); + default -> { + int[] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = safe2Int(labels[i]); + } + yield of(labelsAsInt); + } + }; + } + + private static TensorAddress ofUnsafe(int label) { + return new TensorAddressAny1(label); + } + + private static TensorAddress ofUnsafe(int label0, int label1) { + return new TensorAddressAny2(label0, label1); + } + + private static TensorAddress ofUnsafe(int label0, int label1, int label2) { + return new TensorAddressAny3(label0, label1, label2); + } + + private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(label0, label1, label2, label3); + } + + public static TensorAddress ofUnsafe(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(labels[0]); + case 2 -> ofUnsafe(labels[0], labels[1]); + case 3 -> ofUnsafe(labels[0], labels[1], labels[2]); + case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]); + default -> new TensorAddressAnyN(labels); + }; + } + + private static int sanitize(int label) { + if (label < Tensor.invalidIndex) { + throw new IndexOutOfBoundsException("cell label " + label + " must be positive"); + } + return label; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java new file mode 100644 index 00000000000..a9be6173781 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java @@ -0,0 +1,41 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * A one-dimensional address. 
+ * + * @author baldersheim + */ +final class TensorAddressAny1 extends TensorAddressAny { + + private final int label; + + TensorAddressAny1(int label) { this.label = label; } + + @Override public int size() { return 1; } + + @Override + public long numericLabel(int i) { + if (i == 0) { + return label; + } + throw new IndexOutOfBoundsException("Index is not zero: " + i); + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label)); + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override public int hashCode() { return Math.abs(label); } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny1 any) && (label == any.label); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java new file mode 100644 index 00000000000..43f65d495cf --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A two-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAny2 extends TensorAddressAny { + + private final int label0, label1; + + TensorAddressAny2(int label0, int label1) { + this.label0 = label0; + this.label1 = label1; + } + + @Override public int size() { return 2; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1); + case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java new file mode 100644 index 00000000000..c22ff47b3c4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java @@ -0,0 +1,61 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A three-dimensional address. 
+ * + * @author baldersheim + */ +final class TensorAddressAny3 extends TensorAddressAny { + + private final int label0, label1, label2; + + TensorAddressAny3(int label0, int label1, int label2) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + } + + @Override public int size() { return 3; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2); + case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2); + case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny3 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java new file mode 100644 index 00000000000..6eb6b9216bf --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java @@ -0,0 +1,66 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A four-dimensional address. 
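The hashCode implementations in these fixed-arity address classes pack each label's magnitude into the hash by shifting it past the significant bits of the labels already included. A simplified restatement of that scheme, illustrative only and not the shipped code:

public class PackHashSketch {
    static int pack(int... labels) {
        int hash = Math.abs(labels[0]);
        int usedBits = 32 - Integer.numberOfLeadingZeros(Math.abs(labels[0]));
        for (int i = 1; i < labels.length; i++) {
            hash |= Math.abs(labels[i]) << usedBits;                       // OR in above the bits already used
            usedBits += 32 - Integer.numberOfLeadingZeros(Math.abs(labels[i]));
        }
        return hash;
    }

    public static void main(String[] args) {
        // Small labels pack without overlapping: 3 is 0b11 (2 bits), 5 is 0b101,
        // so the combined hash is 0b101_11 = 23.
        System.out.println(pack(3, 5)); // 23
    }
}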
+ * + * @author baldersheim + */ +final class TensorAddressAny4 extends TensorAddressAny { + + private final int label0, label1, label2, label3; + + TensorAddressAny4(int label0, int label1, int label2, int label3) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + this.label3 = label3; + } + + @Override public int size() { return 4; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + case 3 -> label3; + default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3); + case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3); + case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3); + case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) | + (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny4 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2) && + (label3 == any.label3); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java new file mode 100644 index 00000000000..d5bac62bf18 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; + +import static java.lang.Math.abs; + +/** + * An n-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAnyN extends TensorAddressAny { + + private final int[] labels; + + TensorAddressAnyN(int[] labels) { + if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label"); + this.labels = labels; + } + + @Override public int size() { return labels.length; } + + @Override public long numericLabel(int i) { return labels[i]; } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + int[] copy = Arrays.copyOf(labels, labels.length); + copy[labelIndex] = Convert.safe2Int(label); + return new TensorAddressAnyN(copy); + } + + @Override public int hashCode() { + int hash = abs(labels[0]); + for (int i = 0; i < size(); i++) { + hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash))); + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (! 
(o instanceof TensorAddressAnyN any) || (size() != any.size())) return false; + for (int i = 0; i < size(); i++) { + if (labels[i] != any.labels[i]) return false; + } + return true; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java new file mode 100644 index 00000000000..eb7e62e913b --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java @@ -0,0 +1,33 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * A zero-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressEmpty extends TensorAddressAny { + + static TensorAddress empty = new TensorAddressEmpty(); + + private TensorAddressEmpty() {} + + @Override public int size() { return 0; } + + @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override + public int hashCode() { return 0; } + + @Override + public boolean equals(Object o) { return o instanceof TensorAddressEmpty; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java new file mode 100644 index 00000000000..6b004bf2d02 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java @@ -0,0 +1,6 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.tensor.impl; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index ca9527fd681..32e74c0f132 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -56,22 +56,22 @@ public class DenseBinaryFormat implements BinaryFormat { } private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putDouble(tensor.get(i)); } private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putFloat(tensor.getFloat(i)); } private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); } private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.put((byte) tensor.getFloat(i)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 444ce02b14a..5598690e0bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -16,15 +16,7 @@ import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Slice; - -import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; -import java.util.List; -import java.util.Set; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -60,8 +52,7 @@ public class JsonFormat { // Short form for a single mapped dimension Cursor parent = root == null ? slime.setObject() : root.setObject("cells"); encodeSingleDimensionCells((MappedTensor) tensor, parent); - } else if (tensor instanceof MixedTensor && - tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) { + } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) { // Short form for a mixed tensor boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1; Cursor parent = root == null ? ( singleMapped ? 
slime.setObject() : slime.setArray() ) @@ -143,9 +134,9 @@ public class JsonFormat { } private static void encodeBlocks(MixedTensor tensor, Cursor cursor) { - var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped()) + var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped) .map(d -> TensorType.Dimension.mapped(d.name())).toList(); - if (mappedDimensions.size() < 1) { + if (mappedDimensions.isEmpty()) { throw new IllegalArgumentException("Should be ensured by caller"); } @@ -179,23 +170,6 @@ public class JsonFormat { cursor.setDouble(field, value); } - private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) { - TensorAddress.Builder builder = new TensorAddress.Builder(subType); - for (TensorType.Dimension dim : subType.dimensions()) { - builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()). - orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index")))); - } - return builder.build(); - } - - private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) { - List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size()); - for (int i = 0; i < subAddress.size(); ++i) { - sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i))); - } - return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate(); - } - /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { @@ -204,7 +178,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) + else if (root.field("values").valid() && ! builder.type().hasMappedDimensions()) decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); @@ -298,14 +272,14 @@ public class JsonFormat { /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { - boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + boolean hasIndexed = builder.type().hasIndexedDimensions(); + boolean hasMapped = builder.type().hasMappedDimensions(); if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! 
hasMapped) decodeValuesAtTop(root, builder); - else if (hasMapped && hasIndexed) + else if (hasIndexed) decodeBlocks(root, builder); else decodeCells(root, builder); @@ -423,9 +397,7 @@ public class JsonFormat { if (decoded.length == 0) { throw new IllegalArgumentException("The block value string does not contain any values"); } - for (int i = 0; i < decoded.length; i++) { - values[i] = decoded[i]; - } + System.arraycopy(decoded, 0, values, 0, decoded.length); } else { throw new IllegalArgumentException("Expected a block to contain an array of values"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index bdeb9add41a..3a117e41461 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -48,7 +48,7 @@ class SparseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation + buffer.putInt1_4Bytes(tensor.sizeAsInt()); // XXX: Size truncation switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index d4b18c73f11..0a5c713f3e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -55,8 +55,8 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMappedDimensions = tensor.type().hasMappedDimensions(); + boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions(); boolean isMixed = hasMappedDimensions && hasIndexedDimensions; // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..528ca57d256 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -45,12 +46,7 @@ public class IndexedTensorTestCase { @Test public void testNegativeLabels() { - TensorAddress numeric = TensorAddress.of(-1, 0, 1, 1234567, -1234567); - assertEquals("-1", numeric.label(0)); - assertEquals("0", numeric.label(1)); - assertEquals("1", numeric.label(2)); - assertEquals("1234567", numeric.label(3)); - assertEquals("-1234567", numeric.label(4)); + assertThrows(IndexOutOfBoundsException.class, () ->TensorAddress.of(-1, 0, 1, 1234567, -1234567)); } private void verifyFloat(String 
spec) { @@ -96,6 +92,38 @@ public class IndexedTensorTestCase { } @Test + public void testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getDirectIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getDirectIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java new file mode 100644 index 00000000000..dd40e3105bf --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java @@ -0,0 +1,83 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import static com.yahoo.tensor.TensorAddress.of; +import static com.yahoo.tensor.TensorAddress.ofLabels; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +/** + * Test for tensor address. 
+ * + * @author baldersheim + */ +public class TensorAddressTestCase { + public static void equal(TensorAddress a, TensorAddress b) { + assertEquals(a.hashCode(), b.hashCode()); + assertEquals(a, b); + assertEquals(a.size(), b.size()); + for (int i = 0; i < a.size(); i++) { + assertEquals(a.label(i), b.label(i)); + assertEquals(a.numericLabel(i), b.numericLabel(i)); + } + } + public static void notEqual(TensorAddress a, TensorAddress b) { + assertNotEquals(a.hashCode(), b.hashCode()); // This might not hold, but is bad if not very rare + assertNotEquals(a, b); + } + @Test + void testStringVersusNumericAddressEquality() { + equal(ofLabels("0"), of(0)); + equal(ofLabels("1"), of(1)); + } + @Test + void testInEquality() { + notEqual(ofLabels("1"), ofLabels("2")); + notEqual(of(1), of(2)); + notEqual(ofLabels("1"), ofLabels("01")); + notEqual(ofLabels("0"), ofLabels("00")); + } + @Test + void testDimensionsEffectsEqualityAndHash() { + notEqual(ofLabels("1"), ofLabels("1", "1")); + notEqual(of(1), of(1, 1)); + } + @Test + void testAllowNullDimension() { + TensorAddress s1 = ofLabels("1", null, "2"); + TensorAddress s2 = ofLabels("1", "2"); + assertNotEquals(s1, s2); + assertEquals(-1, s1.numericLabel(1)); + assertNull(s1.label(1)); + } + + private static void verifyWithLabel(int dimensions) { + int [] indexes = new int[dimensions]; + Arrays.fill(indexes, 1); + TensorAddress next = of(indexes); + for (int i = 0; i < dimensions; i++) { + indexes[i] = 3; + assertEquals(of(indexes), next = next.withLabel(i, 3)); + } + } + @Test + void testWithLabel() { + for (int i=0; i < 10; i++) { + verifyWithLabel(i); + } + } + + @Test + void testPartialCopy() { + var abcd = ofLabels("a", "b", "c", "d"); + int[] o_1_3_2 = {1,3,2}; + equal(ofLabels("b", "d", "c"), abcd.partialCopy(o_1_3_2)); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 5c4d5f1ffcf..91880c9af93 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -9,8 +9,10 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import java.util.*; -import java.util.stream.Collectors; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + /** * Microbenchmark of tensor operations. 
@@ -71,7 +73,7 @@ public class TensorFunctionBenchmark { for (int i = 0; i < vectorCount; i++) { Tensor.Builder builder = Tensor.Builder.of(type); for (int j = 0; j < vectorSize; j++) { - builder.cell().label("x", String.valueOf(j)).value(random.nextDouble()); + builder.cell().label("x", j).value(random.nextDouble()); } tensors.add(builder.build()); } @@ -86,12 +88,12 @@ public class TensorFunctionBenchmark { for (int i = 0; i < vectorCount; i++) { for (int j = 0; j < vectorSize; j++) { builder.cell() - .label("i", String.valueOf(i)) - .label("x", String.valueOf(j)) + .label("i", i) + .label("x", j) .value(random.nextDouble()); } } - return Collections.singletonList(builder.build()); + return List.of(builder.build()); } private static TensorType vectorType(TensorType.Builder builder, String name, TensorType.Dimension.Type type, int size) { @@ -107,45 +109,53 @@ public class TensorFunctionBenchmark { public static void main(String[] args) { double time = 0; - // ---------------- Mapped with extra space (sidesteps current special-case optimizations): - // 7.8 ms - time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // 7.7 ms - time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + // ---------------- Indexed unbound: + + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + + // ---------------- Indexed bound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); + + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Mapped: - // 2.1 ms time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); - // 7.0 ms + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): - // 14.5 ms time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); - // 8.9 ms + time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, 
TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed unbound: - // 0.14 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + // ---------------- Mapped with extra space (sidesteps current special-case optimizations): + time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed bound: - // 0.32 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + + /* 2.4Ghz Intel Core i9, Macbook Pro 2019 + Indexed unbound vectors, time per join: 0,066 ms + Indexed unbound matrix, time per join: 0,108 ms + Indexed bound vectors, time per join: 0,068 ms + Indexed bound matrix, time per join: 0,106 ms + Mapped vectors, time per join: 0,845 ms + Mapped matrix, time per join: 1,779 ms + Indexed vectors, x space time per join: 5,778 ms + Indexed matrix, x space time per join: 3,342 ms + Mapped vectors, x space time per join: 8,184 ms + Mapped matrix, x space time per join: 11,547 ms + */ } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 7cf0bd35b38..85619dca16c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -33,7 +33,7 @@ public class DynamicTensorTestCase { public void testDynamicMappedRank1TensorFunction() { TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor<Name> t2 = DynamicTensor.from(sparse, - Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + java.util.Map.of(new TensorAddress.Builder(sparse).add("x", "a").build(), new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java new file mode 100644 index 00000000000..18ff1f6a1d3 --- /dev/null +++ 
b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java @@ -0,0 +1,35 @@ +package com.yahoo.tensor.impl; + +import static com.yahoo.tensor.impl.TensorAddressAny.of; +import static com.yahoo.tensor.TensorAddressTestCase.equal; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +/** + * @author baldersheim + */ +public class TensorAddressAnyTestCase { + + @Test + void testSize() { + for (int i = 0; i < 10; i++) { + int[] indexes = new int[i]; + assertEquals(i, of(indexes).size()); + } + } + + @Test + void testNumericStringEquality() { + for (int i = 0; i < 10; i++) { + int[] numericIndexes = new int[i]; + String[] stringIndexes = new String[i]; + for (int j = 0; j < i; j++) { + numericIndexes[j] = j; + stringIndexes[j] = String.valueOf(j); + } + equal(of(stringIndexes), of(numericIndexes)); + } + } + +} diff --git a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp index 2f09e331c5d..c07743554ea 100644 --- a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp +++ b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp @@ -40,12 +40,14 @@ struct MyInvCmp { struct Timer { double minTime; - vespalib::Timer timer; - Timer() : minTime(1.0e10), timer() {} - void start() { timer = vespalib::Timer(); } + vespalib::steady_time start_time; + Timer() : minTime(1.0e10), start_time() {} + void start() { + start_time = vespalib::steady_clock::now(); + } void stop() { - double ms = vespalib::count_ms(timer.elapsed()); - minTime = std::min(minTime, ms); + std::chrono::duration<double,std::milli> elapsed = vespalib::steady_clock::now() - start_time; + minTime = std::min(minTime, elapsed.count()); } }; diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp index e29b91d6522..bd4ff5d93a9 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp @@ -11,15 +11,15 @@ namespace { class Initialize { public: - Initialize() { Fast_UnicodeUtil::InitTables(); } + Initialize() noexcept { Fast_UnicodeUtil::InitTables(); } }; -Initialize _G_Initializer; +Initialize _g_initializer; } void -Fast_UnicodeUtil::InitTables() +Fast_UnicodeUtil::InitTables() noexcept { /** * Hack for Katakana accent marks (torgeir) @@ -29,8 +29,7 @@ Fast_UnicodeUtil::InitTables() } char * -Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, - int maxdst, int maxsrc) +Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc) noexcept { char * p = dst; char * edst = dst + maxdst; @@ -83,7 +82,7 @@ Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int -Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) +Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) noexcept { ucs4_t i1; ucs4_t i2; @@ -101,7 +100,7 @@ Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) } size_t -Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) +Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) noexcept { const ucs4_t *p = str; while (*p++ != 0) { @@ -111,7 +110,7 @@ Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) } ucs4_t * -Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) +Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) noexcept { ucs4_t i; ucs4_t *p; @@ -127,7 +126,7 @@ Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) } ucs4_t -Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) 
+Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) noexcept { ucs4_t retval; @@ -222,7 +221,7 @@ Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) } ucs4_t -Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) +Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) noexcept { return (*src >= 0x80) ? GetUTF8CharNonAscii(src) @@ -246,7 +245,7 @@ Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) #define UTF8_STARTCHAR(c) (!((c) & 0x80) || ((c) & 0x40)) int Fast_UnicodeUtil::UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset) + unsigned const char*& pos, off_t offset) noexcept { int increment = offset > 0 ? 1 : -1; unsigned const char* p = pos; diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.h b/vespalib/src/vespa/fastlib/text/unicodeutil.h index 87c09826948..740cc9381b7 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.h +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.h @@ -16,7 +16,7 @@ using ucs4_t = uint32_t; * Used to examine properties of unicode characters, and * provide fast conversion methods between often used encodings. */ -class Fast_UnicodeUtil { +class Fast_UnicodeUtil final { private: /** * Is true when the tables have been initialized. Is set by @@ -46,9 +46,8 @@ private: }; public: - virtual ~Fast_UnicodeUtil() { } /** Initialize the ISO 8859-1 static tables. */ - static void InitTables(); + static void InitTables() noexcept; /** Indicates an invalid UTF-8 character sequence. */ enum { _BadUTF8Char = 0xfffffffeu }; @@ -64,7 +63,7 @@ public: * one or more of the properties alphabetic, ideographic, * combining char, decimal digit char, private use, extender. */ - static bool IsWordChar(ucs4_t testchar) { + static bool IsWordChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _wordcharProp) != 0); @@ -80,8 +79,8 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8Char(const unsigned char *& src); - static ucs4_t GetUTF8Char(const char *& src) { + static ucs4_t GetUTF8Char(const unsigned char *& src) noexcept; + static ucs4_t GetUTF8Char(const char *& src) noexcept { const unsigned char *temp = reinterpret_cast<const unsigned char *>(src); ucs4_t res = GetUTF8Char(temp); src = reinterpret_cast<const char *>(temp); @@ -94,7 +93,7 @@ public: * @param i The UCS4 character. * @return Pointer to the next position in dst after the putted byte(s). */ - static char *utf8cput(char *dst, ucs4_t i) { + static char *utf8cput(char *dst, ucs4_t i) noexcept { if (i < 128) *dst++ = i; else if (i < 0x800) { @@ -132,14 +131,14 @@ public: * @param src The UTF-8 source buffer. * @return A pointer to the destination string. */ - static ucs4_t *ucs4copy(ucs4_t *dst, const char *src); + static ucs4_t *ucs4copy(ucs4_t *dst, const char *src) noexcept; /** * Get the length of the UTF-8 representation of an UCS4 character. * @param i The UCS4 character. * @return The number of bytes required for the UTF-8 representation. */ - static size_t utf8clen(ucs4_t i) { + static size_t utf8clen(ucs4_t i) noexcept { if (i < 128) return 1; else if (i < 0x800) @@ -159,7 +158,7 @@ public: * @param testchar The character to lowercase. * @return The lowercase of the input, if defined. Else the input character. 
*/ - static ucs4_t ToLower(ucs4_t testchar) + static ucs4_t ToLower(ucs4_t testchar) noexcept { ucs4_t ret; if (testchar < 65536) { @@ -182,14 +181,14 @@ public: * @return Number of bytes moved, or -1 if out of range */ static int UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset); + unsigned const char*& pos, off_t offset) noexcept; /** * Find the number of characters in an UCS4 string. * @param str The UCS4 string. * @return The number of characters. */ - static size_t ucs4strlen(const ucs4_t *str); + static size_t ucs4strlen(const ucs4_t *str) noexcept; /** * Convert UCS4 to UTF-8, bounded by max lengths. @@ -199,7 +198,7 @@ public: * @param maxsrc The maximum number of characters to convert from src. * @return A pointer to the destination. */ - static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc); + static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc) noexcept; /** @@ -210,7 +209,7 @@ public: * if s1 is, respectively, less than, matching, or greater than s2. * NB Only used in local test */ - static int utf8cmp(const char *s1, const ucs4_t *s2); + static int utf8cmp(const char *s1, const ucs4_t *s2) noexcept; /** * Test for terminal punctuation. @@ -218,7 +217,7 @@ public: * @return true if testchar is a terminal punctuation character, * i.e. if it has the terminal punctuation char property. */ - static bool IsTerminalPunctuationChar(ucs4_t testchar) { + static bool IsTerminalPunctuationChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _terminalPunctuationCharProp) != 0); @@ -233,10 +232,10 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src); + static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src) noexcept; // this is really an alias of the above function - static ucs4_t GetUTF8CharNonAscii(const char *&src) { + static ucs4_t GetUTF8CharNonAscii(const char *&src) noexcept { unsigned const char *temp = reinterpret_cast<unsigned const char *>(src); ucs4_t res = GetUTF8CharNonAscii(temp); src = reinterpret_cast<const char *>(temp); diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_hash_dictionary_read_snapshot.hpp b/vespalib/src/vespa/vespalib/datastore/unique_store_hash_dictionary_read_snapshot.hpp index f416f329331..d3349044fd9 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_hash_dictionary_read_snapshot.hpp +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_hash_dictionary_read_snapshot.hpp @@ -3,6 +3,7 @@ #pragma once #include "unique_store_hash_dictionary_read_snapshot.h" +#include <algorithm> namespace vespalib::datastore { diff --git a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h index 095da1d7c7c..490582b5bf7 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h +++ b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h @@ -101,7 +101,7 @@ public: explicit ExplicitLevenshteinDfaImpl(bool is_cased) noexcept : _is_cased(is_cased) {} - ~ExplicitLevenshteinDfaImpl() override = default; + ~ExplicitLevenshteinDfaImpl() override; static constexpr uint8_t max_edits() noexcept { return MaxEdits; } diff --git a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp index 5265178cef4..55dd459ff26 100644 --- 
a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp +++ b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp @@ -94,6 +94,9 @@ struct ExplicitDfaMatcher { }; template <uint8_t MaxEdits> +ExplicitLevenshteinDfaImpl<MaxEdits>::~ExplicitLevenshteinDfaImpl() = default; + +template <uint8_t MaxEdits> LevenshteinDfa::MatchResult ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str) const { ExplicitDfaMatcher<MaxEdits> matcher(_nodes, _is_cased); diff --git a/vespalib/src/vespa/vespalib/portal/portal.cpp b/vespalib/src/vespa/vespalib/portal/portal.cpp index 8e91e2b5caf..32cc9e4c644 100644 --- a/vespalib/src/vespa/vespalib/portal/portal.cpp +++ b/vespalib/src/vespa/vespalib/portal/portal.cpp @@ -4,6 +4,7 @@ #include "http_connection.h" #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/host_name.h> +#include <algorithm> #include <cassert> namespace vespalib {
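Note, not part of the patch above: a minimal, illustrative Java sketch of how the numeric tensor-address API touched by this change is expected to behave, based only on the assertions in TensorAddressTestCase and TensorAddressAnyTestCase in the diff. The class name TensorAddressSketch is hypothetical; TensorAddress.of, ofLabels, label, numericLabel and withLabel are taken from the tests above, and the "expected" comments mirror those tests rather than guaranteeing behavior.

import com.yahoo.tensor.TensorAddress;

public class TensorAddressSketch {

    public static void main(String[] args) {
        // Numeric and string labels should yield equal addresses (see testStringVersusNumericAddressEquality).
        TensorAddress numeric = TensorAddress.of(1, 2);
        TensorAddress labeled = TensorAddress.ofLabels("1", "2");
        System.out.println(numeric.equals(labeled));                  // expected: true
        System.out.println(numeric.hashCode() == labeled.hashCode()); // expected: true

        // Labels can be read back either as strings or as numeric values.
        System.out.println(numeric.label(0));        // expected: "1"
        System.out.println(numeric.numericLabel(1)); // expected: 2

        // withLabel returns a new address with a single dimension replaced (see testWithLabel).
        TensorAddress updated = numeric.withLabel(1, 3);
        System.out.println(updated.equals(TensorAddress.of(1, 3)));   // expected: true

        // Negative numeric labels are rejected after this change (see testNegativeLabels).
        try {
            TensorAddress.of(-1, 0, 1, 1234567, -1234567);
        } catch (IndexOutOfBoundsException e) {
            System.out.println("negative label rejected");
        }
    }
}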