159 files changed, 4108 insertions, 2271 deletions
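One of the larger refactors below is easier to follow with its pattern pulled out. The config server changes (ApplicationRepository.java, SessionRepository.java) replace a pre-computed Map<ApplicationId, Long> of active sessions with a Predicate<Session> that the repository consults lazily, once per candidate session. A minimal sketch of that pattern follows; Session, SessionExpirer, and activeIdFor are hypothetical stand-ins for the real classes, not the actual implementation:

    import java.util.List;
    import java.util.Optional;
    import java.util.function.Predicate;

    // Hypothetical stand-in for the real session type.
    record Session(long id, Optional<String> applicationId) { }

    class SessionExpirer {

        // Instead of pre-building an application -> active-session-id map for all
        // tenants up front, the caller passes a predicate that is evaluated once
        // per session, at the moment the session is inspected.
        int deleteExpired(List<Session> sessions, Predicate<Session> isActiveForItsApplication) {
            int deleted = 0;
            for (Session session : sessions) {
                if (session.applicationId().isEmpty()) continue; // keep ownerless sessions, as the real code does
                if ( ! isActiveForItsApplication.test(session)) {
                    delete(session);
                    deleted++;
                }
            }
            return deleted;
        }

        private void delete(Session session) { /* remove from disk / ZooKeeper */ }
    }

    // A caller would pass something like (activeIdFor is hypothetical):
    //   session -> activeIdFor(session).equals(Optional.of(session.id()))

Because the predicate closes over the tenant's application repository, activeness is checked at inspection time rather than against a snapshot that can go stale while the sweep runs.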
diff --git a/client/go/internal/vespa/crypto.go b/client/go/internal/vespa/crypto.go index 9b4d776d97d..568d7a84d18 100644 --- a/client/go/internal/vespa/crypto.go +++ b/client/go/internal/vespa/crypto.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/pem" + "errors" "fmt" "io" "math/big" @@ -220,3 +221,15 @@ func randomSerialNumber() (*big.Int, error) { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) return rand.Int(rand.Reader, serialNumberLimit) } + +// isTLSAlert returns whether err contains a TLS alert error. +func isTLSAlert(err error) bool { + for ; err != nil; err = errors.Unwrap(err) { + // This is ugly, but alert types are currently not exposed: + // https://github.com/golang/go/issues/35234 + if fmt.Sprintf("%T", err) == "tls.alert" { + return true + } + } + return false +} diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index 90d1e1997da..ed3cb146eb1 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -153,7 +153,11 @@ func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Respon if err := s.CurlWriter.print(request, s.TLSOptions, timeout); err != nil { return nil, err } - return s.httpClient.Do(request, timeout) + resp, err := s.httpClient.Do(request, timeout) + if isTLSAlert(err) { + return nil, fmt.Errorf("%w: %s", errAuth, err) + } + return resp, err } // SetClient sets a custom HTTP client that this service should use. diff --git a/client/js/app/yarn.lock b/client/js/app/yarn.lock index e9dc5bf25fe..ebf1c99db13 100644 --- a/client/js/app/yarn.lock +++ b/client/js/app/yarn.lock @@ -1311,70 +1311,70 @@ resolved "https://registry.yarnpkg.com/@remix-run/router/-/router-1.14.2.tgz#4d58f59908d9197ba3179310077f25c88e49ed17" integrity sha512-ACXpdMM9hmKZww21yEqWwiLws/UPLhNKvimN8RrYSqPSvB3ov7sLvAcfvaxePeLvccTQKGdkDIhLYApZVDFuKg== -"@rollup/rollup-android-arm-eabi@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.4.tgz#b1094962742c1a0349587040bc06185e2a667c9b" - integrity sha512-ub/SN3yWqIv5CWiAZPHVS1DloyZsJbtXmX4HxUTIpS0BHm9pW5iYBo2mIZi+hE3AeiTzHz33blwSnhdUo+9NpA== - -"@rollup/rollup-android-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.4.tgz#96eb86fb549e05b187f2ad06f51d191a23cb385a" - integrity sha512-ehcBrOR5XTl0W0t2WxfTyHCR/3Cq2jfb+I4W+Ch8Y9b5G+vbAecVv0Fx/J1QKktOrgUYsIKxWAKgIpvw56IFNA== - -"@rollup/rollup-darwin-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.4.tgz#2456630c007cc5905cb368acb9ff9fc04b2d37be" - integrity sha512-1fzh1lWExwSTWy8vJPnNbNM02WZDS8AW3McEOb7wW+nPChLKf3WG2aG7fhaUmfX5FKw9zhsF5+MBwArGyNM7NA== - -"@rollup/rollup-darwin-x64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.4.tgz#97742214fc7dfd47a0f74efba6f5ae264e29c70c" - integrity sha512-Gc6cukkF38RcYQ6uPdiXi70JB0f29CwcQ7+r4QpfNpQFVHXRd0DfWFidoGxjSx1DwOETM97JPz1RXL5ISSB0pA== - -"@rollup/rollup-linux-arm-gnueabihf@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.4.tgz#cd933e61d6f689c9cdefde424beafbd92cfe58e2" - integrity sha512-g21RTeFzoTl8GxosHbnQZ0/JkuFIB13C3T7Y0HtKzOXmoHhewLbVTFBQZu+z5m9STH6FZ7L/oPgU4Nm5ErN2fw== - -"@rollup/rollup-linux-arm64-gnu@4.9.4": - version "4.9.4" - resolved 
"https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.4.tgz#33b09bf462f1837afc1e02a1b352af6b510c78a6" - integrity sha512-TVYVWD/SYwWzGGnbfTkrNpdE4HON46orgMNHCivlXmlsSGQOx/OHHYiQcMIOx38/GWgwr/po2LBn7wypkWw/Mg== - -"@rollup/rollup-linux-arm64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.4.tgz#50257fb248832c2308064e3764a16273b6ee4615" - integrity sha512-XcKvuendwizYYhFxpvQ3xVpzje2HHImzg33wL9zvxtj77HvPStbSGI9czrdbfrf8DGMcNNReH9pVZv8qejAQ5A== - -"@rollup/rollup-linux-riscv64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.4.tgz#09589e4e1a073cf56f6249b77eb6c9a8e9b613a8" - integrity sha512-LFHS/8Q+I9YA0yVETyjonMJ3UA+DczeBd/MqNEzsGSTdNvSJa1OJZcSH8GiXLvcizgp9AlHs2walqRcqzjOi3A== - -"@rollup/rollup-linux-x64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.4.tgz#bd312bb5b5f02e54d15488605d15cfd3f90dda7c" - integrity sha512-dIYgo+j1+yfy81i0YVU5KnQrIJZE8ERomx17ReU4GREjGtDW4X+nvkBak2xAUpyqLs4eleDSj3RrV72fQos7zw== - -"@rollup/rollup-linux-x64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.4.tgz#25b3bede85d86438ce28cc642842d10d867d40e9" - integrity sha512-RoaYxjdHQ5TPjaPrLsfKqR3pakMr3JGqZ+jZM0zP2IkDtsGa4CqYaWSfQmZVgFUCgLrTnzX+cnHS3nfl+kB6ZQ== - -"@rollup/rollup-win32-arm64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.4.tgz#95957067eb107f571da1d81939f017d37b4958d3" - integrity sha512-T8Q3XHV+Jjf5e49B4EAaLKV74BbX7/qYBRQ8Wop/+TyyU0k+vSjiLVSHNWdVd1goMjZcbhDmYZUYW5RFqkBNHQ== - -"@rollup/rollup-win32-ia32-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.4.tgz#71b6facad976db527863f698692c6964c0b6e10e" - integrity sha512-z+JQ7JirDUHAsMecVydnBPWLwJjbppU+7LZjffGf+Jvrxq+dVjIE7By163Sc9DKc3ADSU50qPVw0KonBS+a+HQ== - -"@rollup/rollup-win32-x64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.4.tgz#16295ccae354707c9bc6842906bdeaad4f3ba7a5" - integrity sha512-LfdGXCV9rdEify1oxlN9eamvDSjv9md9ZVMAbNHA87xqIfFCxImxan9qZ8+Un54iK2nnqPlbnSi4R54ONtbWBw== +"@rollup/rollup-android-arm-eabi@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.5.tgz#b752b6c88a14ccfcbdf3f48c577ccc3a7f0e66b9" + integrity sha512-idWaG8xeSRCfRq9KpRysDHJ/rEHBEXcHuJ82XY0yYFIWnLMjZv9vF/7DOq8djQ2n3Lk6+3qfSH8AqlmHlmi1MA== + +"@rollup/rollup-android-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.5.tgz#33757c3a448b9ef77b6f6292d8b0ec45c87e9c1a" + integrity sha512-f14d7uhAMtsCGjAYwZGv6TwuS3IFaM4ZnGMUn3aCBgkcHAYErhV1Ad97WzBvS2o0aaDv4mVz+syiN0ElMyfBPg== + +"@rollup/rollup-darwin-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.5.tgz#5234ba62665a3f443143bc8bcea9df2cc58f55fb" + integrity sha512-ndoXeLx455FffL68OIUrVr89Xu1WLzAG4n65R8roDlCoYiQcGGg6MALvs2Ap9zs7AHg8mpHtMpwC8jBBjZrT/w== + +"@rollup/rollup-darwin-x64@4.9.5": + version "4.9.5" + resolved 
"https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.5.tgz#981256c054d3247b83313724938d606798a919d1" + integrity sha512-UmElV1OY2m/1KEEqTlIjieKfVwRg0Zwg4PLgNf0s3glAHXBN99KLpw5A5lrSYCa1Kp63czTpVll2MAqbZYIHoA== + +"@rollup/rollup-linux-arm-gnueabihf@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.5.tgz#120678a5a2b3a283a548dbb4d337f9187a793560" + integrity sha512-Q0LcU61v92tQB6ae+udZvOyZ0wfpGojtAKrrpAaIqmJ7+psq4cMIhT/9lfV6UQIpeItnq/2QDROhNLo00lOD1g== + +"@rollup/rollup-linux-arm64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.5.tgz#c99d857e2372ece544b6f60b85058ad259f64114" + integrity sha512-dkRscpM+RrR2Ee3eOQmRWFjmV/payHEOrjyq1VZegRUa5OrZJ2MAxBNs05bZuY0YCtpqETDy1Ix4i/hRqX98cA== + +"@rollup/rollup-linux-arm64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.5.tgz#3064060f568a5718c2a06858cd6e6d24f2ff8632" + integrity sha512-QaKFVOzzST2xzY4MAmiDmURagWLFh+zZtttuEnuNn19AiZ0T3fhPyjPPGwLNdiDT82ZE91hnfJsUiDwF9DClIQ== + +"@rollup/rollup-linux-riscv64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.5.tgz#987d30b5d2b992fff07d055015991a57ff55fbad" + integrity sha512-HeGqmRJuyVg6/X6MpE2ur7GbymBPS8Np0S/vQFHDmocfORT+Zt76qu+69NUoxXzGqVP1pzaY6QIi0FJWLC3OPA== + +"@rollup/rollup-linux-x64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.5.tgz#85946ee4d068bd12197aeeec2c6f679c94978a49" + integrity sha512-Dq1bqBdLaZ1Gb/l2e5/+o3B18+8TI9ANlA1SkejZqDgdU/jK/ThYaMPMJpVMMXy2uRHvGKbkz9vheVGdq3cJfA== + +"@rollup/rollup-linux-x64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.5.tgz#fe0b20f9749a60eb1df43d20effa96c756ddcbd4" + integrity sha512-ezyFUOwldYpj7AbkwyW9AJ203peub81CaAIVvckdkyH8EvhEIoKzaMFJj0G4qYJ5sw3BpqhFrsCc30t54HV8vg== + +"@rollup/rollup-win32-arm64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.5.tgz#422661ef0e16699a234465d15b2c1089ef963b2a" + integrity sha512-aHSsMnUw+0UETB0Hlv7B/ZHOGY5bQdwMKJSzGfDfvyhnpmVxLMGnQPGNE9wgqkLUs3+gbG1Qx02S2LLfJ5GaRQ== + +"@rollup/rollup-win32-ia32-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.5.tgz#7b73a145891c202fbcc08759248983667a035d85" + integrity sha512-AiqiLkb9KSf7Lj/o1U3SEP9Zn+5NuVKgFdRIZkvd4N0+bYrTOovVd0+LmYCPQGbocT4kvFyK+LXCDiXPBF3fyA== + +"@rollup/rollup-win32-x64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.5.tgz#10491ccf4f63c814d4149e0316541476ea603602" + integrity sha512-1q+mykKE3Vot1kaFJIDoUFv5TuW+QQVaf2FmTT9krg86pQrGStOSJJ0Zil7CFagyxDuouTepzt5Y5TVzyajOdQ== "@sinclair/typebox@^0.27.8": version "0.27.8" @@ -4670,17 +4670,17 @@ react-refresh@^0.14.0: integrity sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ== react-router-dom@^6: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.2.tgz#5fba851731a194fa32c31990c4829c5e247f650a" - integrity 
sha512-tE13UukgUOh2/sqYr6jPzZTzmzc70aGRP4pAjG2if0IP3aUT+sBtAKUJh0qMh0zylJHGLmzS+XWVaON4UklHeg== + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.3.tgz#ef3a7956a3699c7b82c21fcb3dbc63c313ed8c5d" + integrity sha512-kNzubk7n4YHSrErzjLK72j0B5i969GsuCGazRl3G6j1zqZBLjuSlYBdVdkDOgzGdPIffUOc9nmgiadTEVoq91g== dependencies: "@remix-run/router" "1.14.2" - react-router "6.21.2" + react-router "6.21.3" -react-router@6.21.2: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.2.tgz#8820906c609ae7e4e8f926cc8eb5ce161428b956" - integrity sha512-jJcgiwDsnaHIeC+IN7atO0XiSRCrOsQAHHbChtJxmgqG2IaYQXSnhqGb5vk2CU/wBQA12Zt+TkbuJjIn65gzbA== +react-router@6.21.3: + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.3.tgz#8086cea922c2bfebbb49c6594967418f1f167d70" + integrity sha512-a0H638ZXULv1OdkmiK6s6itNhoy33ywxmUFT/xtSoVyf9VnC7n7+VT4LjVzdIHSaF5TIh9ylUgxMXksHTgGrKg== dependencies: "@remix-run/router" "1.14.2" @@ -4845,25 +4845,25 @@ rimraf@^3.0.2: glob "^7.1.3" rollup@^4.2.0: - version "4.9.4" - resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.4.tgz#37bc0c09ae6b4538a9c974f4d045bb64b2e7c27c" - integrity sha512-2ztU7pY/lrQyXSCnnoU4ICjT/tCG9cdH3/G25ERqE3Lst6vl2BCM5hL2Nw+sslAvAf+ccKsAq1SkKQALyqhR7g== + version "4.9.5" + resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.5.tgz#62999462c90f4c8b5d7c38fc7161e63b29101b05" + integrity sha512-E4vQW0H/mbNMw2yLSqJyjtkHY9dslf/p0zuT1xehNRqUTBOFMqEjguDvqhXr7N7r/4ttb2jr4T41d3dncmIgbQ== dependencies: "@types/estree" "1.0.5" optionalDependencies: - "@rollup/rollup-android-arm-eabi" "4.9.4" - "@rollup/rollup-android-arm64" "4.9.4" - "@rollup/rollup-darwin-arm64" "4.9.4" - "@rollup/rollup-darwin-x64" "4.9.4" - "@rollup/rollup-linux-arm-gnueabihf" "4.9.4" - "@rollup/rollup-linux-arm64-gnu" "4.9.4" - "@rollup/rollup-linux-arm64-musl" "4.9.4" - "@rollup/rollup-linux-riscv64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-musl" "4.9.4" - "@rollup/rollup-win32-arm64-msvc" "4.9.4" - "@rollup/rollup-win32-ia32-msvc" "4.9.4" - "@rollup/rollup-win32-x64-msvc" "4.9.4" + "@rollup/rollup-android-arm-eabi" "4.9.5" + "@rollup/rollup-android-arm64" "4.9.5" + "@rollup/rollup-darwin-arm64" "4.9.5" + "@rollup/rollup-darwin-x64" "4.9.5" + "@rollup/rollup-linux-arm-gnueabihf" "4.9.5" + "@rollup/rollup-linux-arm64-gnu" "4.9.5" + "@rollup/rollup-linux-arm64-musl" "4.9.5" + "@rollup/rollup-linux-riscv64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-musl" "4.9.5" + "@rollup/rollup-win32-arm64-msvc" "4.9.5" + "@rollup/rollup-win32-ia32-msvc" "4.9.5" + "@rollup/rollup-win32-x64-msvc" "4.9.5" fsevents "~2.3.2" rsvp@^4.8.4: @@ -5474,9 +5474,9 @@ v8-to-istanbul@^9.0.1: convert-source-map "^1.6.0" vite@^5.0.5: - version "5.0.11" - resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.11.tgz#31562e41e004cb68e1d51f5d2c641ab313b289e4" - integrity sha512-XBMnDjZcNAw/G1gEiskiM1v6yzM4GE5aMGvhWTlHAYYhxb7S3/V1s3m2LDHa8Vh6yIWYYB0iJwsEaS523c4oYA== + version "5.0.12" + resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.12.tgz#8a2ffd4da36c132aec4adafe05d7adde38333c47" + integrity sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w== dependencies: esbuild "^0.19.3" postcss "^8.4.32" diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index 
eb5942bd49c..a6910c059fc 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -112,8 +112,8 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"bjorncs"}, removeAfter = "8.289") default boolean dynamicHeapSize() { return true; } @ModelFeatureFlag(owners = {"hmusum"}) default String unknownConfigDefinition() { return "warn"; } @ModelFeatureFlag(owners = {"hmusum"}) default int searchHandlerThreadpool() { return 2; } - @ModelFeatureFlag(owners = {"vekterli"}) default long mergingMaxMemoryUsagePerNode() { return -1; } - @ModelFeatureFlag(owners = {"vekterli"}) default boolean usePerDocumentThrottledDeleteBucket() { return false; } + @ModelFeatureFlag(owners = {"vekterli"}, removeAfter = "8.292.x") default long mergingMaxMemoryUsagePerNode() { return 0; } + @ModelFeatureFlag(owners = {"vekterli"}, removeAfter = "8.292.x") default boolean usePerDocumentThrottledDeleteBucket() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean alwaysMarkPhraseExpensive() { return false; } @ModelFeatureFlag(owners = {"hmusum"}) default boolean restartOnDeployWhenOnnxModelChanges() { return false; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean sortBlueprintsByCost() { return false; } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index 6f0254145d6..b088231e84a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -84,8 +84,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private List<DataplaneToken> dataplaneTokens; private int contentLayerMetadataFeatureLevel = 0; private boolean dynamicHeapSize = false; - private long mergingMaxMemoryUsagePerNode = -1; - private boolean usePerDocumentThrottledDeleteBucket = false; private boolean restartOnDeployWhenOnnxModelChanges = false; @Override public ModelContext.FeatureFlags featureFlags() { return this; } @@ -144,8 +142,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public List<DataplaneToken> dataplaneTokens() { return dataplaneTokens; } @Override public int contentLayerMetadataFeatureLevel() { return contentLayerMetadataFeatureLevel; } @Override public boolean dynamicHeapSize() { return dynamicHeapSize; } - @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } - @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } public TestProperties sharedStringRepoNoReclaim(boolean sharedStringRepoNoReclaim) { @@ -379,16 +375,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea public TestProperties setDynamicHeapSize(boolean b) { this.dynamicHeapSize = b; return this; } - public TestProperties setMergingMaxMemoryUsagePerNode(long maxUsage) { - this.mergingMaxMemoryUsagePerNode = maxUsage; - return this; - } - - public TestProperties setUsePerDocumentThrottledDeleteBucket(boolean enableThrottling) { - this.usePerDocumentThrottledDeleteBucket = enableThrottling; - return this; - } - public TestProperties setRestartOnDeployForOnnxModelChanges(boolean 
enable) { this.restartOnDeployWhenOnnxModelChanges = enable; return this; diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 502b054f84e..9b3e236612a 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -615,18 +615,6 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addSummaryFeature(ReferenceNode feature) { - if (summaryFeatures == null) - summaryFeatures = new LinkedHashSet<>(); - summaryFeatures.add(feature); - } - - private void addMatchFeature(ReferenceNode feature) { - if (matchFeatures == null) - matchFeatures = new LinkedHashSet<>(); - matchFeatures.add(feature); - } - private void addImplicitMatchFeatures(List<FeatureList> list) { if (hiddenMatchFeatures == null) hiddenMatchFeatures = new LinkedHashSet<>(); @@ -642,15 +630,19 @@ public class RankProfile implements Cloneable { /** Adds the content of the given feature list to the internal list of summary features. */ public void addSummaryFeatures(FeatureList features) { + if (summaryFeatures == null) + summaryFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addSummaryFeature(feature); + summaryFeatures.add(feature); } } /** Adds the content of the given feature list to the internal list of match features. */ public void addMatchFeatures(FeatureList features) { + if (matchFeatures == null) + matchFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addMatchFeature(feature); + matchFeatures.add(feature); } } @@ -661,20 +653,16 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addRankFeature(ReferenceNode feature) { - if (rankFeatures == null) - rankFeatures = new LinkedHashSet<>(); - rankFeatures.add(feature); - } - /** * Adds the content of the given feature list to the internal list of rank features. * * @param features The features to add. 
*/ public void addRankFeatures(FeatureList features) { + if (rankFeatures == null) + rankFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addRankFeature(feature); + rankFeatures.add(feature); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java index 18b9129cead..56ca23523b6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/FileStorProducer.java @@ -47,7 +47,6 @@ public class FileStorProducer implements StorFilestorConfig.Producer { private final int responseNumThreads; private final StorFilestorConfig.Response_sequencer_type.Enum responseSequencerType; private final boolean useAsyncMessageHandlingOnSchedule; - private final boolean usePerDocumentThrottledDeleteBucket; private static StorFilestorConfig.Response_sequencer_type.Enum convertResponseSequencerType(String sequencerType) { try { @@ -63,9 +62,9 @@ public class FileStorProducer implements StorFilestorConfig.Producer { this.responseNumThreads = featureFlags.defaultNumResponseThreads(); this.responseSequencerType = convertResponseSequencerType(featureFlags.responseSequencerType()); this.useAsyncMessageHandlingOnSchedule = featureFlags.useAsyncMessageHandlingOnSchedule(); - this.usePerDocumentThrottledDeleteBucket = featureFlags.usePerDocumentThrottledDeleteBucket(); } + @Override public void getConfig(StorFilestorConfig.Builder builder) { if (numThreads != null) { @@ -75,7 +74,7 @@ public class FileStorProducer implements StorFilestorConfig.Producer { builder.num_response_threads(responseNumThreads); builder.response_sequencer_type(responseSequencerType); builder.use_async_message_handling_on_schedule(useAsyncMessageHandlingOnSchedule); - builder.use_per_document_throttled_delete_bucket(usePerDocumentThrottledDeleteBucket); + builder.use_per_document_throttled_delete_bucket(true); var throttleBuilder = new StorFilestorConfig.Async_operation_throttler.Builder(); builder.async_operation_throttler(throttleBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java index 1865db0ec1c..81fc19de929 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorServerProducer.java @@ -29,7 +29,6 @@ public class StorServerProducer implements StorServerConfig.Producer { private final String clusterName; private Integer maxMergesPerNode; private Integer queueSize; - private Long mergingMaxMemoryUsagePerNode; private StorServerProducer setMaxMergesPerNode(Integer value) { if (value != null) { @@ -46,7 +45,6 @@ public class StorServerProducer implements StorServerConfig.Producer { StorServerProducer(String clusterName, ModelContext.FeatureFlags featureFlags) { this.clusterName = clusterName; - this.mergingMaxMemoryUsagePerNode = featureFlags.mergingMaxMemoryUsagePerNode(); } @Override @@ -63,10 +61,8 @@ public class StorServerProducer implements StorServerConfig.Producer { if (queueSize != null) { builder.max_merge_queue_size(queueSize); } - if (mergingMaxMemoryUsagePerNode != null) { - builder.merge_throttling_memory_limit( - new 
StorServerConfig.Merge_throttling_memory_limit.Builder() - .max_usage_bytes(mergingMaxMemoryUsagePerNode)); - } + builder.merge_throttling_memory_limit( + new StorServerConfig.Merge_throttling_memory_limit.Builder() + .max_usage_bytes(0)); } } diff --git a/config-model/src/test/derived/rankprofileinheritance/child.sd b/config-model/src/test/derived/rankprofileinheritance/child.sd index 2517d0731f5..8348a62838c 100644 --- a/config-model/src/test/derived/rankprofileinheritance/child.sd +++ b/config-model/src/test/derived/rankprofileinheritance/child.sd @@ -39,4 +39,15 @@ schema child { } + rank-profile profile5 inherits profile1 { + match-features { + attribute(field3) + } + } + + rank-profile profile6 inherits profile1 { + summary-features { } + match-features { } + } + } diff --git a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg index a3bc6791412..ccf52da3b5e 100644 --- a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg +++ b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg @@ -52,3 +52,23 @@ rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "rankingExpression(function4)" rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "function4" +rankprofile[].name "profile5" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "attribute(field1)" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(field3)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "function1" +rankprofile[].name "profile6" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" diff --git a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java index c959634019d..e920672646f 100644 --- a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java @@ -1,11 +1,17 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.schema; +import com.yahoo.document.Document; import com.yahoo.schema.document.Stemming; import com.yahoo.schema.parser.ParseException; import com.yahoo.schema.processing.ImportedFieldsResolver; import com.yahoo.schema.processing.OnnxModelTypeResolver; import com.yahoo.vespa.documentmodel.DocumentSummary; +import com.yahoo.vespa.indexinglanguage.expressions.AttributeExpression; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; +import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression; +import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.model.test.utils.DeployLoggerStub; import org.junit.jupiter.api.Test; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java index bdd61d93136..7c3e66aa109 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/StorageClusterTest.java @@ -189,7 +189,7 @@ public class StorageClusterTest { var config = configFromProperties(new TestProperties()); var limit = config.merge_throttling_memory_limit(); - assertEquals(-1L, limit.max_usage_bytes()); // TODO change default + assertEquals(0L, limit.max_usage_bytes()); assertMergeAutoScaleConfigHasExpectedValues(limit); } @@ -200,21 +200,6 @@ public class StorageClusterTest { } @Test - void merge_throttler_memory_limit_is_controlled_by_feature_flag() { - var config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(-1)); - assertEquals(-1L, config.merge_throttling_memory_limit().max_usage_bytes()); - - config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(0)); - assertEquals(0L, config.merge_throttling_memory_limit().max_usage_bytes()); - - config = configFromProperties(new TestProperties().setMergingMaxMemoryUsagePerNode(1_234_456_789)); - assertEquals(1_234_456_789L, config.merge_throttling_memory_limit().max_usage_bytes()); - - // Feature flag should not affect the other config values - assertMergeAutoScaleConfigHasExpectedValues(config.merge_throttling_memory_limit()); - } - - @Test void testVisitors() { StorVisitorConfig.Builder builder = new StorVisitorConfig.Builder(); parse(cluster("bees", @@ -355,24 +340,6 @@ public class StorageClusterTest { assertTrue(config.async_operation_throttler().throttle_individual_merge_feed_ops()); } - private void verifyUsePerDocumentThrottledDeleteBucket(boolean expected, Boolean enabled) { - var props = new TestProperties(); - if (enabled != null) { - props.setUsePerDocumentThrottledDeleteBucket(enabled); - } - var config = filestorConfigFromProducer(simpleCluster(props)); - assertEquals(expected, config.use_per_document_throttled_delete_bucket()); - } - - @Test - void delete_bucket_throttling_is_controlled_by_feature_flag() { - // TODO update default once rolled out and tested - verifyUsePerDocumentThrottledDeleteBucket(false, null); - - verifyUsePerDocumentThrottledDeleteBucket(false, false); - verifyUsePerDocumentThrottledDeleteBucket(true, true); - } - @Test void testCapacity() { String xml = joinLines( diff --git a/configdefinitions/src/vespa/stor-filestor.def b/configdefinitions/src/vespa/stor-filestor.def index 090f74dec12..950797f8bc2 100644 --- a/configdefinitions/src/vespa/stor-filestor.def +++ b/configdefinitions/src/vespa/stor-filestor.def 
@@ -121,4 +121,4 @@ async_operation_dynamic_throttling_window_increment int default=20 restart ## first place. ## ## This is a live config. -use_per_document_throttled_delete_bucket bool default=false +use_per_document_throttled_delete_bucket bool default=true diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index cac297f061a..32f4d2b653c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -43,6 +43,7 @@ import com.yahoo.slime.Slime; import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.applicationmodel.InfrastructureApplication; +import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints.Token; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprintsClient; import com.yahoo.vespa.config.server.application.Application; @@ -54,7 +55,6 @@ import com.yahoo.vespa.config.server.application.ClusterReindexing; import com.yahoo.vespa.config.server.application.ClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.CompressedApplicationInputStream; import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker; -import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.DefaultClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.FileDistributionStatus; import com.yahoo.vespa.config.server.application.HttpProxy; @@ -111,7 +111,6 @@ import java.time.Duration; import java.time.Instant; import java.util.Collection; import java.util.Comparator; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -135,7 +134,6 @@ import static com.yahoo.vespa.config.server.tenant.TenantRepository.HOSTED_VESPA import static com.yahoo.vespa.curator.Curator.CompletionWaiter; import static com.yahoo.yolean.Exceptions.uncheck; import static java.nio.file.Files.readAttributes; -import static java.util.stream.Collectors.toMap; /** * The API for managing applications. @@ -544,6 +542,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye static void checkIfActiveHasChanged(Session session, Session activeSession, boolean ignoreStaleSessionFailure) { long activeSessionAtCreate = session.getActiveSessionAtCreate(); log.log(Level.FINE, () -> activeSession.logPre() + "active session id at create time=" + activeSessionAtCreate); + if (activeSessionAtCreate == 0) return; // No active session at create time, or session created for indeterminate app. long sessionId = session.getSessionId(); long activeSessionSessionId = activeSession.getSessionId(); @@ -554,7 +553,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye String errMsg = activeSession.logPre() + "Cannot activate session " + sessionId + " because the currently active session (" + activeSessionSessionId + ") has changed since session " + sessionId + " was created (was " + - (activeSessionAtCreate == 0 ? 
"empty" : activeSessionAtCreate) + " at creation time)"; + activeSessionAtCreate + " at creation time)"; if (ignoreStaleSessionFailure) { log.warning(errMsg + " (Continuing because of force.)"); } else { @@ -956,30 +955,25 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } public void deleteExpiredLocalSessions() { - Map<Tenant, Collection<LocalSession>> sessionsPerTenant = new HashMap<>(); - tenantRepository.getAllTenants() - .forEach(tenant -> sessionsPerTenant.put(tenant, tenant.getSessionRepository().getLocalSessions())); - - Set<ApplicationId> applicationIds = new HashSet<>(); - sessionsPerTenant.values() - .forEach(sessionList -> sessionList.stream() - .map(Session::getOptionalApplicationId) - .filter(Optional::isPresent) - .forEach(appId -> applicationIds.add(appId.get()))); - - Map<ApplicationId, Long> activeSessions = new HashMap<>(); - applicationIds.forEach(applicationId -> getActiveSession(applicationId).ifPresent(session -> activeSessions.put(applicationId, session.getSessionId()))); - sessionsPerTenant.keySet().forEach(tenant -> tenant.getSessionRepository().deleteExpiredSessions(activeSessions)); + for (Tenant tenant : tenantRepository.getAllTenants()) { + tenant.getSessionRepository().deleteExpiredSessions(session -> sessionIsActiveForItsApplication(tenant, session)); + } } public int deleteExpiredRemoteSessions(Clock clock) { return tenantRepository.getAllTenants() .stream() - .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock)) + .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock, session -> sessionIsActiveForItsApplication(tenant, session))) .mapToInt(i -> i) .sum(); } + private boolean sessionIsActiveForItsApplication(Tenant tenant, Session session) { + Optional<ApplicationId> owner = session.getOptionalApplicationId(); + if (owner.isEmpty()) return true; // Chicken out ~(˘▾˘)~ + return tenant.getApplicationRepo().activeSessionOf(owner.get()).equals(Optional.of(session.getSessionId())); + } + // ---------------- Tenant operations ---------------------------------------------------------------- diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java index 7c7608e5e5c..c219d59a44e 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java @@ -177,7 +177,7 @@ public class Deployment implements com.yahoo.config.provision.Deployment { nodesToRestart.size(), nodesToRestart.stream().sorted().collect(joining(", ")))); log.info(String.format("%sWill schedule service restart of %d nodes after convergence on generation %d: %s", session.logPre(), nodesToRestart.size(), session.getSessionId(), nodesToRestart.stream().sorted().collect(joining(", ")))); - this.configChangeActions = configChangeActions.withRestartActions(new RestartActions()); + configChangeActions = configChangeActions == null ? 
null : configChangeActions.withRestartActions(new RestartActions()); } private void storeReindexing(ApplicationId applicationId) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 22b2b581b44..d500e56d079 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -201,13 +201,11 @@ public class ModelContextImpl implements ModelContext { private final int rpc_events_before_wakeup; private final int heapPercentage; private final String summaryDecodePolicy; - private boolean sortBlueprintsByCost; + private final boolean sortBlueprintsByCost; private final boolean alwaysMarkPhraseExpensive; private final int contentLayerMetadataFeatureLevel; private final String unknownConfigDefinition; private final int searchHandlerThreadpool; - private final long mergingMaxMemoryUsagePerNode; - private final boolean usePerDocumentThrottledDeleteBucket; private final boolean restartOnDeployWhenOnnxModelChanges; public FeatureFlags(FlagSource source, ApplicationId appId, Version version) { @@ -249,8 +247,6 @@ public class ModelContextImpl implements ModelContext { this.contentLayerMetadataFeatureLevel = flagValue(source, appId, version, Flags.CONTENT_LAYER_METADATA_FEATURE_LEVEL); this.unknownConfigDefinition = flagValue(source, appId, version, Flags.UNKNOWN_CONFIG_DEFINITION); this.searchHandlerThreadpool = flagValue(source, appId, version, Flags.SEARCH_HANDLER_THREADPOOL); - this.mergingMaxMemoryUsagePerNode = flagValue(source, appId, version, Flags.MERGING_MAX_MEMORY_USAGE_PER_NODE); - this.usePerDocumentThrottledDeleteBucket = flagValue(source, appId, version, Flags.USE_PER_DOCUMENT_THROTTLED_DELETE_BUCKET); this.alwaysMarkPhraseExpensive = flagValue(source, appId, version, Flags.ALWAYS_MARK_PHRASE_EXPENSIVE); this.restartOnDeployWhenOnnxModelChanges = flagValue(source, appId, version, Flags.RESTART_ON_DEPLOY_WHEN_ONNX_MODEL_CHANGES); this.sortBlueprintsByCost = flagValue(source, appId, version, Flags.SORT_BLUEPRINTS_BY_COST); @@ -303,8 +299,6 @@ public class ModelContextImpl implements ModelContext { @Override public int contentLayerMetadataFeatureLevel() { return contentLayerMetadataFeatureLevel; } @Override public String unknownConfigDefinition() { return unknownConfigDefinition; } @Override public int searchHandlerThreadpool() { return searchHandlerThreadpool; } - @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } - @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } @Override public boolean sortBlueprintsByCost() { return sortBlueprintsByCost; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java index dcc5d7caa0d..4f10b1215cf 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.config.server.maintenance; import com.yahoo.config.FileReference; 
import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.TenantName; import com.yahoo.config.subscription.ConfigSourceSet; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Transport; @@ -19,8 +20,10 @@ import com.yahoo.vespa.filedistribution.FileReferenceDownload; import java.io.File; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.Future; import java.util.logging.Logger; import static com.yahoo.vespa.config.server.filedistribution.FileDistributionUtil.fileReferenceExistsOnDisk; @@ -51,39 +54,60 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { @Override protected double maintain() { int attempts = 0; - int failures = 0; - - for (var applicationId : applicationRepository.listApplications()) { - if (shuttingDown()) - break; - - log.finest(() -> "Verifying application package for " + applicationId); - Optional<Session> session = applicationRepository.getActiveSession(applicationId); - if (session.isEmpty()) continue; // App might be deleted after call to listApplications() or not activated yet (bootstrap phase) - - Optional<FileReference> appFileReference = session.get().getApplicationPackageReference(); - if (appFileReference.isPresent()) { - long sessionId = session.get().getSessionId(); - attempts++; - if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { - log.fine(() -> "Downloading application package with file reference " + appFileReference + - " for " + applicationId + " (session " + sessionId + ")"); - - FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), - this.getClass().getSimpleName(), - false); - if (fileDownloader.getFile(download).isEmpty()) { - failures++; - log.info("Downloading application package (" + appFileReference + ")" + - " for " + applicationId + " (session " + sessionId + ") unsuccessful. 
" + - "Can be ignored unless it happens many times over a long period of time, retries is expected"); + int[] failures = new int[1]; + + List<Runnable> futureDownloads = new ArrayList<>(); + for (TenantName tenantName : applicationRepository.tenantRepository().getAllTenantNames()) { + for (Session session : applicationRepository.tenantRepository().getTenant(tenantName).getSessionRepository().getRemoteSessions()) { + if (shuttingDown()) + break; + + switch (session.getStatus()) { + case PREPARE, ACTIVATE: + break; + default: continue; + } + + var applicationId = session.getApplicationId(); + log.finest(() -> "Verifying application package for " + applicationId); + + Optional<FileReference> appFileReference = session.getApplicationPackageReference(); + if (appFileReference.isPresent()) { + long sessionId = session.getSessionId(); + attempts++; + if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { + log.fine(() -> "Downloading application package with file reference " + appFileReference + + " for " + applicationId + " (session " + sessionId + ")"); + + FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), + this.getClass().getSimpleName(), + false); + Future<Optional<File>> futureDownload = fileDownloader.getFutureFileOrTimeout(download); + futureDownloads.add(() -> { + try { + if (futureDownload.get().isPresent()) { + createLocalSessionIfMissing(applicationId, sessionId); + return; + } + } + catch (Exception ignored) { } + failures[0]++; + log.info("Downloading application package (" + appFileReference + ")" + + " for " + applicationId + " (session " + sessionId + ") unsuccessful. " + + "Can be ignored unless it happens many times over a long period of time, retries is expected"); + }); + } + else { + createLocalSessionIfMissing(applicationId, sessionId); } } - createLocalSessionIfMissing(applicationId, sessionId); } } - return asSuccessFactorDeviation(attempts, failures); + + futureDownloads.forEach(Runnable::run); + + return asSuccessFactorDeviation(attempts, failures[0]); } private static FileDownloader createFileDownloader(ApplicationRepository applicationRepository, @@ -92,7 +116,7 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { List<String> otherConfigServersInCluster = getOtherConfigServersInCluster(applicationRepository.configserverConfig()); ConfigSourceSet configSourceSet = new ConfigSourceSet(otherConfigServersInCluster); ConnectionPool connectionPool = new FileDistributionConnectionPool(configSourceSet, supervisor); - return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(300)); + return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(60)); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 52c11ed0e93..2f0d8b4065d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -79,6 +79,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; @@ -369,7 +370,7 @@ public class SessionRepository { return session; } - public int 
deleteExpiredRemoteSessions(Clock clock) { + public int deleteExpiredRemoteSessions(Clock clock, Predicate<Session> sessionIsActiveForApplication) { Duration expiryTime = Duration.ofSeconds(expiryTimeFlag.value()); List<Long> remoteSessionsFromZooKeeper = getRemoteSessionsFromZooKeeper(); log.log(Level.FINE, () -> "Remote sessions for tenant " + tenantName + ": " + remoteSessionsFromZooKeeper); @@ -377,11 +378,11 @@ public class SessionRepository { int deleted = 0; // Avoid deleting too many in one run int deleteMax = (int) Math.min(1000, Math.max(50, remoteSessionsFromZooKeeper.size() * 0.05)); - for (long sessionId : remoteSessionsFromZooKeeper) { + for (Long sessionId : remoteSessionsFromZooKeeper) { Session session = remoteSessionCache.get(sessionId); if (session == null) session = new RemoteSession(tenantName, sessionId, createSessionZooKeeperClient(sessionId)); - if (session.getStatus() == Session.Status.ACTIVATE) continue; + if (session.getStatus() == Session.Status.ACTIVATE && sessionIsActiveForApplication.test(session)) continue; if (sessionHasExpired(session.getCreateTime(), expiryTime, clock)) { log.log(Level.FINE, () -> "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it"); deleteRemoteSessionFromZooKeeper(session); @@ -616,7 +617,7 @@ public class SessionRepository { // ---------------- Common stuff ---------------------------------------------------------------- - public void deleteExpiredSessions(Map<ApplicationId, Long> activeSessions) { + public void deleteExpiredSessions(Predicate<Session> sessionIsActiveForApplication) { log.log(Level.FINE, () -> "Deleting expired local sessions for tenant '" + tenantName + "'"); Set<Long> sessionIdsToDelete = new HashSet<>(); Set<Long> newSessions = findNewSessionsInFileSystem(); @@ -650,8 +651,7 @@ public class SessionRepository { Optional<ApplicationId> applicationId = session.getOptionalApplicationId(); if (applicationId.isEmpty()) continue; - Long activeSession = activeSessions.get(applicationId.get()); - if (activeSession == null || activeSession != sessionId) { + if ( ! 
sessionIsActiveForApplication.test(session)) { sessionIdsToDelete.add(sessionId); log.log(Level.FINE, () -> "Will delete inactive session " + sessionId + " created " + createTime + " for '" + applicationId + "'"); diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 2e666089152..069b7ffc496 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -111,6 +111,7 @@ </handler> <handler id='com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/prepareandactivate</binding> + <binding>http://*/application/v2/tenant/*/prepareandactivate/*</binding> </handler> <handler id='com.yahoo.vespa.config.server.http.v2.SessionContentHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/session/*/content/*</binding> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java index 891284a3a0e..e0a58888109 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java @@ -206,7 +206,7 @@ public class FileServerTest { super(FileDownloader.emptyConnectionPool(), new Supervisor(new Transport("mock")).setDropEmptyBuffers(true), downloadDirectory, - Duration.ofMillis(100), + Duration.ofMillis(1000), Duration.ofMillis(100)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java index 36892860295..d4aa0676c4f 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java @@ -212,15 +212,34 @@ class ApplicationApiHandlerTest { @Test void testActivationFailuresAndRetries() throws Exception { - // Prepare session 2, but fail on hosts; this session will be activated later. - provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + // Prepare session 2, and activate it successfully. verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 2 for tenant 'test' prepared, but activation failed: host still booting", + "message": "Session 2 for tenant 'test' prepared and activated.", "session-id": "2", + "activated": true, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 3, but fail on hosts; this session will be activated later. 
+ provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 3 for tenant 'test' prepared, but activation failed: host still booting", + "session-id": "3", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -232,15 +251,15 @@ class ApplicationApiHandlerTest { } """); - // Prepare session 3, but fail on lock; this session will become outdated later. + // Prepare session 4, but fail on lock; this session will become outdated later. provisioner.activationFailure(new ApplicationLockException("lock timeout")); verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 3 for tenant 'test' prepared, but activation failed: lock timeout", - "session-id": "3", + "message": "Session 4 for tenant 'test' prepared, but activation failed: lock timeout", + "session-id": "4", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -263,9 +282,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with hosts. + // Retry only activation of session 3, but fail again with hosts. provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 409, """ { @@ -274,9 +293,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with lock. + // Retry only activation of session 3, but fail again with lock. provisioner.activationFailure(new ApplicationLockException("lock timeout")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 500, """ { @@ -285,33 +304,33 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, and succeed! + // Retry only activation of session 3, and succeed! provisioner.activationFailure(null); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 200, """ { - "message": "Session 2 for test.default.default activated" + "message": "Session 3 for test.default.default activated" } """); - // Retry only activation of session 3, but fail because it is now based on an outdated session. - verifyResponse(put(3, Map.of()), + // Retry only activation of session 4, but fail because it is now based on an outdated session. + verifyResponse(put(4, Map.of()), 409, """ { "error-code": "ACTIVATION_CONFLICT", - "message": "app:test.default.default Cannot activate session 3 because the currently active session (2) has changed since session 3 was created (was empty at creation time)" + "message": "app:test.default.default Cannot activate session 4 because the currently active session (3) has changed since session 4 was created (was 2 at creation time)" } """); - // Retry activation of session 2 again, and fail. - verifyResponse(put(2, Map.of()), + // Retry activation of session 3 again, and fail. 
+ verifyResponse(put(3, Map.of()), 400, """ { "error-code": "BAD_REQUEST", - "message": "app:test.default.default Session 2 is already active" + "message": "app:test.default.default Session 3 is already active" } """); } diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java index 0010291de66..43bd175b348 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java @@ -11,6 +11,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.QueryCanonicalizer; import com.yahoo.prelude.query.RankItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WeakAndItem; import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.parser.ParserEnvironment; @@ -79,8 +80,8 @@ public class AllParser extends SimpleParser { // Combine the items Item topLevel = and; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } @@ -130,6 +131,7 @@ public class AllParser extends SimpleParser { if ( ! tokens.skip(MINUS)) return null; if (tokens.currentIsNoIgnore(SPACE)) return null; var itemAndExplicitIndex = indexableItem(); + item = itemAndExplicitIndex.getFirst(); boolean explicitIndex = itemAndExplicitIndex.getSecond(); if (item == null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java index efc804fcf1f..1bbc21768b5 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java @@ -12,6 +12,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.TermItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.search.query.parser.ParserEnvironment; import java.util.Iterator; @@ -106,9 +107,8 @@ public class AnyParser extends SimpleParser { } return rank; } else if ((topLevelItem instanceof RankItem) - && (item instanceof RankItem) + && (item instanceof RankItem itemAsRank) && (((RankItem) item).getItem(0) instanceof OrItem)) { - RankItem itemAsRank = (RankItem) item; OrItem or = (OrItem) itemAsRank.getItem(0); ((RankItem) topLevelItem).addItem(0, or); @@ -139,8 +139,10 @@ public class AnyParser extends SimpleParser { if (root instanceof PhraseItem) { root.setFilter(true); } - for (Iterator<Item> i = ((CompositeItem) root).getItemIterator(); i.hasNext();) { - markAllTermsAsFilters(i.next()); + if (root instanceof CompositeItem composite) { + for (Iterator<Item> i = composite.getItemIterator(); i.hasNext(); ) { + markAllTermsAsFilters(i.next()); + } } } } @@ -206,8 +208,7 @@ public class AnyParser extends SimpleParser { return root; } - if (root instanceof RankItem) { - RankItem rootAsRank = (RankItem) root; + if (root instanceof RankItem rootAsRank) { Item firstChild = rootAsRank.getItem(0); if (firstChild instanceof NotItem) { @@ -228,7 +229,6 @@ public class AnyParser extends SimpleParser { } NotItem not = new NotItem(); - not.addPositiveItem(root); not.addNegativeItem(item); return not; diff --git 
a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java index deab2be9d00..ea0cd2312a6 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java @@ -130,14 +130,13 @@ abstract class SimpleParser extends StructuredParser { } } if (not != null && not.getPositiveItem() instanceof TrueItem) { - // Incomplete not, only negatives - - + // Incomplete not, only negatives - simplify when possible if (topLevelItem != null && topLevelItem != not) { // => neutral rank items becomes implicit positives not.addPositiveItem(getItemAsPositiveItem(topLevelItem, not)); return not; - } else { // Only negatives - ignore them - return null; + } else { + return not; } } if (topLevelItem != null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java index 06ea583c53f..75396a8714f 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java @@ -6,6 +6,7 @@ import com.yahoo.prelude.query.CompositeItem; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NotItem; import com.yahoo.prelude.query.OrItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.search.query.parser.ParserEnvironment; @@ -69,8 +70,8 @@ public class WebParser extends AllParser { if (or != null) topLevel = or; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index bef766e7ef9..70f6e405a92 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,6 +15,7 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; + private final Map<String, String> contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -30,6 +31,7 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; + this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -44,6 +46,9 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } + /** Returns a read-only map of context key-values which can be looked up during conversion. 
*/ + Map<String,String> contextValues() { return contextValues; } + /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cfadd79de8f..e16f8e7b0cd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,7 +48,8 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, + context.language(), context.contextValues()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index c9f935e5f52..25a5c277dce 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,7 +44,8 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage()); + query.getModel().getLanguage(), + context); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 6da53ae699c..94f92c7fd48 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,7 +19,8 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); private final Map<String, Embedder> embedders; @@ -27,8 +28,9 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language) { - var context = new Embedder.Context(key).setLanguage(language); + public Tensor convertTo(TensorType type, String key, Object value, Language language, + Map<String, String> contextValues) { + var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! 
tensor.type().isAssignableTo(type)) @@ -55,16 +57,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { + Matcher matcher; + if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); argument = matcher.group(2); - if ( ! embedders.containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - } - embedder = embedders.get(embedderId); - } else if (embedders.size() == 0) { + } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { + embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); + argument = matcher.group(2); + } else if (embedders.isEmpty()) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -74,19 +76,35 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); } - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { + private Embedder requireEmbedder(String embedderId) { + if ( ! embedders.containsKey(embedderId)) + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. 
" + + "Valid embedders are " + validEmbedders(embedders)); + return embedders.get(embedderId); + } + + private static String resolve(String s, Embedder.Context embedderContext) { + if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { + if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); - } + if (s.startsWith("@")) + return resolveReference(s, embedderContext); return s; } + private static String resolveReference(String s, Embedder.Context embedderContext) { + String referenceKey = s.substring(1); + String referencedValue = embedderContext.getContextValues().get(referenceKey); + if (referencedValue == null) + throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + + "' used in an embed() argument"); + return referencedValue; + } + private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); diff --git a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java index 4383be184fa..156c34e5005 100644 --- a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.SubstringItem; import com.yahoo.prelude.query.SuffixItem; import com.yahoo.prelude.query.TaggableItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.language.process.SpecialTokens; import com.yahoo.prelude.query.parser.TestLinguistics; @@ -262,17 +263,28 @@ public class ParseTestCase { @Test void testNotOnly() { - tester.assertParsed(null, "-foobar", Query.Type.ALL); + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test - void testMultipleNotsOnlt() { - tester.assertParsed(null, "-foo -bar -foobar", Query.Type.ALL); + void testNotOnlyAny() { + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); + } + + @Test + void testMultipleNotsOnly() { + Item item = tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test void testOnlyNotComposite() { - tester.assertParsed(null, "-(foo bar baz)", Query.Type.ALL); + tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.ALL); } @Test @@ -391,7 +403,7 @@ public class ParseTestCase { @Test void testMinusAndPluses() { - tester.assertParsed(null, "--test+-if", Query.Type.ANY); + tester.assertParsed("-(AND test if)", "--test+-if", Query.Type.ANY); } @Test @@ -1305,7 +1317,9 @@ public class ParseTestCase { @Test void testNotFilterEmptyQuery() { - tester.assertParsed(null, "", "-foo", Query.Type.ANY); + Item item = tester.assertParsed("-|foo", "", "-foo", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test @@ -1380,7 +1394,7 @@ public class ParseTestCase { @Test void 
testMultitermNotFilterEmptyQuery() {
- tester.assertParsed(null, "", "-foo -foz", Query.Type.ANY);
+ tester.assertParsed("-|foo -|foz", "", "-foo -foz", Query.Type.ANY);
} @Test
@@ -2320,17 +2334,19 @@ public class ParseTestCase {
@Test void testNotOnlyWeb() {
- tester.assertParsed(null, "-foobar", Query.Type.WEB);
+ Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.WEB);
+ assertTrue(item instanceof NotItem);
+ assertTrue(((NotItem)item).getPositiveItem() instanceof TrueItem);
} @Test void testMultipleNotsOnltWeb() {
- tester.assertParsed(null, "-foo -bar -foobar", Query.Type.WEB);
+ tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.WEB);
} @Test void testOnlyNotCompositeWeb() {
- tester.assertParsed(null, "-(foo bar baz)", Query.Type.WEB);
+ tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.WEB);
} @Test
diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
index 90e21e5f3b0..429b8d1c6cb 100644
--- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
+++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
@@ -185,6 +185,21 @@ public class RankProfileInputTest {
assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode());
}
+ @Test
+ void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameter() {
+ String text = "text to embed into a tensor";
+ Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+ );
+ assertEmbedQuery("embed(@param1)", embedding1, embedders, null, text);
+ assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, text);
+ assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders,
+ "Could not resolve query parameter reference 'noSuchParam' " +
+ "used in an embed() argument");
+ }
+
private Query createTensor1Query(String tensorString, String profile, String additionalParams) {
return new Query.Builder()
.setSchemaInfo(createSchemaInfo())
@@ -202,18 +217,24 @@ public class RankProfileInputTest {
}
private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) {
- assertEmbedQuery(embed, expected, embedders, null);
+ assertEmbedQuery(embed, expected, embedders, null, null);
}
private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
+ assertEmbedQuery(embed, expected, embedders, language, null);
+ }
+ private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) {
String languageParam = language == null ? "" : "&language=" + language;
+ String param1 = param1Value == null ? "" : "&param1=" + urlEncode(param1Value);
+
String destination = "query(myTensor4)";
Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest(
"?" + urlEncode("ranking.features."
+ destination) + "=" + urlEncode(embed) + "&ranking=commonProfile" + - languageParam, + languageParam + + param1, com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) @@ -230,7 +251,7 @@ public class RankProfileInputTest { if (t.getMessage().equals(errMsg)) return; t = t.getCause(); } - fail("Error '" + errMsg + "' not thrown"); + fail("Exception with message '" + errMsg + "' not thrown"); } private CompiledQueryProfile createQueryProfile() { diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index cbeb4d01fe4..f4ad767c18a 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -66,7 +66,7 @@ <!-- Athenz dependencies. Make sure these dependencies match those in Vespa's internal repositories --> <athenz.vespa.version>1.11.50</athenz.vespa.version> - <aws-sdk.vespa.version>1.12.639</aws-sdk.vespa.version> + <aws-sdk.vespa.version>1.12.643</aws-sdk.vespa.version> <!-- Athenz END --> <!-- WARNING: If you change curator version, you also need to update @@ -90,7 +90,7 @@ <commons-compress.vespa.version>1.25.0</commons-compress.vespa.version> <commons-cli.vespa.version>1.6.0</commons-cli.vespa.version> <curator.vespa.version>5.6.0</curator.vespa.version> - <dropwizard.metrics.vespa.version>4.2.23</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> + <dropwizard.metrics.vespa.version>4.2.24</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> <eclipse-collections.vespa.version>11.1.0</eclipse-collections.vespa.version> <eclipse-sisu.vespa.version>0.9.0.M2</eclipse-sisu.vespa.version> <failureaccess.vespa.version>1.0.2</failureaccess.vespa.version> @@ -120,7 +120,7 @@ <mimepull.vespa.version>1.10.0</mimepull.vespa.version> <mockito.vespa.version>5.9.0</mockito.vespa.version> <mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version> - <netty.vespa.version>4.1.105.Final</netty.vespa.version> + <netty.vespa.version>4.1.106.Final</netty.vespa.version> <netty-tcnative.vespa.version>2.0.62.Final</netty-tcnative.vespa.version> <onnxruntime.vespa.version>1.16.3</onnxruntime.vespa.version> <opennlp.vespa.version>2.3.1</opennlp.vespa.version> diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index a9cd3cc87a8..dec84e46b77 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -150,34 +150,6 @@ public class TokenBuffer { return nesting; } - public Token prefetchScalar(String name) { - int localNesting = nesting(); - int nestingBarrier = localNesting; - Token toReturn = null; - Iterator<Token> i; - - if (name.equals(currentName()) && current().isScalarValue()) { - toReturn = tokens.get(position); - } else { - i = tokens.iterator(); - i.next(); // just ignore the first value, as we know it's not what - // we're looking for, and it's nesting effect is already - // included - while (i.hasNext()) { - Token t = i.next(); - if (localNesting == nestingBarrier && name.equals(t.name) && t.token.isScalarValue()) { - toReturn = t; - break; - } - localNesting += nestingOffset(t.token); - if (localNesting < nestingBarrier) { - break; - } - } - } - return toReturn; - } - public void skipToRelativeNesting(int relativeNesting) { int initialNesting = nesting(); do { diff --git a/document/src/main/java/com/yahoo/document/json/readers/MapReader.java 
b/document/src/main/java/com/yahoo/document/json/readers/MapReader.java index 6c850fe4320..b45a0001fd1 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/MapReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/MapReader.java @@ -90,45 +90,39 @@ public class MapReader { @SuppressWarnings({ "rawtypes", "unchecked" }) public static ValueUpdate createMapUpdate(TokenBuffer buffer, DataType currentLevel, - FieldValue keyParent, - FieldValue topLevelKey, boolean ignoreUndefinedFields) { - TokenBuffer.Token element = buffer.prefetchScalar(UPDATE_ELEMENT); + if ( ! JsonToken.START_OBJECT.equals(buffer.current())) + throw new IllegalArgumentException("Expected object for match update, got " + buffer.current()); + buffer.next(); + + FieldValue key = null; + ValueUpdate update; + if (UPDATE_ELEMENT.equals(buffer.currentName())) { + key = keyTypeForMapUpdate(buffer.currentText(), currentLevel); buffer.next(); } - FieldValue key = keyTypeForMapUpdate(element, currentLevel); - if (keyParent != null) { - ((CollectionFieldValue) keyParent).add(key); - } - // structure is: [(match + element)*, (element + action)] - // match will always have element, and either match or action - if (!UPDATE_MATCH.equals(buffer.currentName())) { - // we have reached an action... - if (topLevelKey == null) { - return ValueUpdate.createMap(key, readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields)); - } else { - return ValueUpdate.createMap(topLevelKey, readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields)); - } - } else { - // next level of matching - if (topLevelKey == null) { - return createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), key, key, ignoreUndefinedFields); - } else { - return createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), key, topLevelKey, ignoreUndefinedFields); - } + update = UPDATE_MATCH.equals(buffer.currentName()) ? createMapUpdate(buffer, valueTypeForMapUpdate(currentLevel), ignoreUndefinedFields) + : readSingleUpdate(buffer, valueTypeForMapUpdate(currentLevel), buffer.currentName(), ignoreUndefinedFields); + buffer.next(); + + if (key == null) { + if ( ! UPDATE_ELEMENT.equals(buffer.currentName())) + throw new IllegalArgumentException("Expected match element, got " + buffer.current()); + key = keyTypeForMapUpdate(buffer.currentText(), currentLevel); + buffer.next(); } + + if ( ! 
JsonToken.END_OBJECT.equals(buffer.current())) + throw new IllegalArgumentException("Expected object end for match update, got " + buffer.current()); + + return ValueUpdate.createMap(key, update); } @SuppressWarnings("rawtypes") public static ValueUpdate createMapUpdate(TokenBuffer buffer, Field field, boolean ignoreUndefinedFields) { - buffer.next(); - MapValueUpdate m = (MapValueUpdate) MapReader.createMapUpdate(buffer, field.getDataType(), null, null, ignoreUndefinedFields); - buffer.next(); - // must generate the field value in parallell with the actual - return m; - + return MapReader.createMapUpdate(buffer, field.getDataType(), ignoreUndefinedFields); } private static DataType valueTypeForMapUpdate(DataType parentType) { @@ -143,14 +137,14 @@ public class MapReader { } } - private static FieldValue keyTypeForMapUpdate(TokenBuffer.Token element, DataType expectedType) { + private static FieldValue keyTypeForMapUpdate(String elementText, DataType expectedType) { FieldValue v; if (expectedType instanceof ArrayDataType) { - v = new IntegerFieldValue(Integer.valueOf(element.text)); + v = new IntegerFieldValue(Integer.valueOf(elementText)); } else if (expectedType instanceof WeightedSetDataType) { - v = ((WeightedSetDataType) expectedType).getNestedType().createFieldValue(element.text); + v = ((WeightedSetDataType) expectedType).getNestedType().createFieldValue(elementText); } else if (expectedType instanceof MapDataType) { - v = ((MapDataType) expectedType).getKeyType().createFieldValue(element.text); + v = ((MapDataType) expectedType).getKeyType().createFieldValue(elementText); } else { throw new IllegalArgumentException("Container type " + expectedType + " not supported for match update."); } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 8a45fe95fa2..5a9f02c790d 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -31,6 +31,7 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.Struct; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.datatypes.WeightedSet; +import com.yahoo.document.fieldpathupdate.FieldPathUpdate; import com.yahoo.document.internal.GeoPosType; import com.yahoo.document.json.readers.DocumentParseInfo; import com.yahoo.document.json.readers.VespaJsonDocumentReader; @@ -62,6 +63,7 @@ import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -82,6 +84,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -151,6 +154,13 @@ public class JsonReaderTestCase { types.registerDocumentType(x); } { + DocumentType x = new DocumentType("testArrayOfArrayOfInt"); + DataType inner = new ArrayDataType(DataType.INT); + DataType outer = new ArrayDataType(inner); + x.addField(new Field("arrayOfArrayOfInt", outer)); + types.registerDocumentType(x); + } + { DocumentType x = new DocumentType("testsinglepos"); DataType d = PositionDataType.INSTANCE; x.addField(new Field("singlepos", 
d));
@@ -211,103 +221,110 @@
} @Test
- public void readSingleDocumentPut() {
- JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::doc1',",
- " 'fields': {",
- " 'something': 'smoketest',",
- " 'flag': true,",
- " 'nalle': 'bamse'",
- " }",
- "}"));
- DocumentPut put = (DocumentPut) r.readSingleDocument(DocumentOperationType.PUT,
- "id:unittest:smoke::doc1").operation();
- smokeTestDoc(put.getDocument());
+ public void readSingleDocumentPut() throws IOException {
+ Document doc = docFromJson("""
+ {
+ "put": "id:unittest:smoke::doc1",
+ "fields": {
+ "something": "smoketest",
+ "flag": true,
+ "nalle": "bamse"
+ }
+ }
+ """);
+ smokeTestDoc(doc);
} @Test
- public final void readSingleDocumentUpdate() {
- JsonReader r = createReader(inputJson("{ 'update': 'id:unittest:smoke::whee',",
- " 'fields': {",
- " 'something': {",
- " 'assign': 'orOther' }}}"));
- DocumentUpdate doc = (DocumentUpdate) r.readSingleDocument(DocumentOperationType.UPDATE, "id:unittest:smoke::whee").operation();
+ public final void readSingleDocumentUpdate() throws IOException {
+ DocumentUpdate doc = parseUpdate("""
+ {
+ "update": "id:unittest:smoke::whee",
+ "fields": {
+ "something": {
+ "assign": "orOther"
+ }
+ }
+ }
+ """);
FieldUpdate f = doc.getFieldUpdate("something");
assertEquals(1, f.size());
assertTrue(f.getValueUpdate(0) instanceof AssignValueUpdate);
+ assertEquals(new StringFieldValue("orOther"), f.getValueUpdate(0).getValue());
} @Test
- public void readClearField() {
- JsonReader r = createReader(inputJson("{ 'update': 'id:unittest:smoke::whee',",
- " 'fields': {",
- " 'int1': {",
- " 'assign': null }}}"));
- DocumentUpdate doc = (DocumentUpdate) r.readSingleDocument(DocumentOperationType.UPDATE, "id:unittest:smoke::whee").operation();
+ public void readClearField() throws IOException {
+ DocumentUpdate doc = parseUpdate("""
+ {
+ "update": "id:unittest:smoke::whee",
+ "fields": {
+ "int1": {
+ "assign": null
+ }
+ }
+ }
+ """);
FieldUpdate f = doc.getFieldUpdate("int1");
assertEquals(1, f.size());
assertTrue(f.getValueUpdate(0) instanceof ClearValueUpdate);
assertNull(f.getValueUpdate(0).getValue());
} - @Test public void smokeTest() throws IOException {
- JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::doc1',",
- " 'fields': {",
- " 'something': 'smoketest',",
- " 'flag': true,",
- " 'nalle': 'bamse'",
- " }",
- "}"));
- DocumentParseInfo parseInfo = r.parseDocument().get();
- DocumentType docType = r.readDocumentType(parseInfo.documentId);
- DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId));
- new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put);
- smokeTestDoc(put.getDocument());
+ Document doc = docFromJson("""
+ {
+ "put": "id:unittest:smoke::doc1",
+ "fields": {
+ "something": "smoketest",
+ "flag": true,
+ "nalle": "bamse"
+ }
+ }
+ """);
+ smokeTestDoc(doc);
} @Test public void docIdLookaheadTest() throws IOException {
- JsonReader r = createReader(inputJson(
- "{ 'fields': {",
- " 'something': 'smoketest',",
- " 'flag': true,",
- " 'nalle': 'bamse'",
- " },",
- " 'put': 'id:unittest:smoke::doc1'",
- " }",
- "}"));
-
- DocumentParseInfo parseInfo = r.parseDocument().get();
- DocumentType docType = r.readDocumentType(parseInfo.documentId);
- DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId));
- new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put);
- smokeTestDoc(put.getDocument());
+ Document doc = docFromJson("""
+ {
+ "fields": {
+ "something": "smoketest",
+ "flag": true,
+ "nalle": "bamse"
+ },
+ "put": "id:unittest:smoke::doc1"
+ }
+ """);
+ smokeTestDoc(doc);
} - @Test public void emptyDocTest() throws IOException {
- JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:smoke::whee', 'fields': {}}"));
- DocumentParseInfo parseInfo = r.parseDocument().get();
- DocumentType docType = r.readDocumentType(parseInfo.documentId);
- DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId));
- new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put);
- assertEquals("id:unittest:smoke::whee", parseInfo.documentId.toString());
+ Document doc = docFromJson("""
+ {
+ "put": "id:unittest:smoke::whee",
+ "fields": { }
+ }""");
+ assertEquals(new Document(types.getDocumentType("smoke"), new DocumentId("id:unittest:smoke::whee")),
+ doc);
} @Test public void testStruct() throws IOException {
- JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:mirrors::whee',",
- " 'fields': {",
- " 'skuggsjaa': {",
- " 'sandra': 'person',",
- " 'cloud': 'another person' }}}"));
- DocumentParseInfo parseInfo = r.parseDocument().get();
- DocumentType docType = r.readDocumentType(parseInfo.documentId);
- DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId));
- new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put);
- Document doc = put.getDocument();
+ Document doc = docFromJson("""
+ {
+ "put": "id:unittest:mirrors::whee",
+ "fields": {
+ "skuggsjaa": {
+ "sandra": "person",
+ "cloud": "another person"
+ }
+ }
+ }
+ """);
FieldValue f = doc.getFieldValue(doc.getField("skuggsjaa"));
assertSame(Struct.class, f.getClass());
Struct s = (Struct) f;
@@ -326,13 +343,20 @@ public class JsonReaderTestCase {
@Test public void testStructUpdate() throws IOException {
- DocumentUpdate put = parseUpdate(inputJson("{ 'update': 'id:unittest:mirrors:g=test:whee',",
- " 'create': true,",
- " 'fields': {",
- " 'skuggsjaa': {",
- " 'assign': {",
- " 'sandra': 'person',",
- " 'cloud': 'another person' }}}}"));
+ DocumentUpdate put = parseUpdate("""
+ {
+ "update": "id:unittest:mirrors:g=test:whee",
+ "create": true,
+ "fields": {
+ "skuggsjaa": {
+ "assign": {
+ "sandra": "person",
+ "cloud": "another person"
+ }
+ }
+ }
+ }
+ """);
assertEquals(1, put.fieldUpdates().size());
FieldUpdate fu = put.fieldUpdates().iterator().next();
assertEquals(1, fu.getValueUpdates().size());
@@ -351,11 +375,17 @@ public class JsonReaderTestCase {
@Test public final void testEmptyStructUpdate() throws IOException {
- DocumentUpdate put = parseUpdate(inputJson("{ 'update': 'id:unittest:mirrors:g=test:whee',",
- " 'create': true,",
- " 'fields': { ",
- " 'skuggsjaa': {",
- " 'assign': { } }}}"));
+ DocumentUpdate put = parseUpdate("""
+ {
+ "update": "id:unittest:mirrors:g=test:whee",
+ "create": true,
+ "fields": {
+ "skuggsjaa": {
+ "assign": { }
+ }
+ }
+ }
+ """);
assertEquals(1, put.fieldUpdates().size());
FieldUpdate fu = put.fieldUpdates().iterator().next();
assertEquals(1, fu.getValueUpdates().size());
@@ -373,23 +403,37 @@ public class JsonReaderTestCase {
@Test public void testUpdateArray() throws IOException {
- DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testarray::whee',",
- " 'fields': {",
- " 'actualarray': {",
- " 'add': [",
- " 'person',",
- " 'another person' ]}}}"));
+ DocumentUpdate doc = parseUpdate("""
+ {
+ "update": "id:unittest:testarray::whee",
+ "fields": {
+ "actualarray": {
+ "add": [
+ "person",
+ "another person"
+ ]
+ }
+ }
+ }
+
"""); checkSimpleArrayAdd(doc); } @Test public void testUpdateWeighted() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'add': {", - " 'person': 37,", - " 'another person': 41 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "add": { + "person": 37, + "another person": 41 + } + } + } + } + """); Map<String, Integer> weights = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -409,12 +453,34 @@ public class JsonReaderTestCase { @Test public void testUpdateMatch() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'match': {", - " 'element': 'person',", - " 'increment': 13 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "element": "person", + "increment": 13 + } + } + } + } + """); + + DocumentUpdate otherDoc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "increment": 13, + "element": "person" + } + } + } + }"""); + + assertEquals(doc, otherDoc); Map<String, Tuple2<Number, String>> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -437,21 +503,28 @@ public class JsonReaderTestCase { @Test public void testArithmeticOperators() throws IOException { Tuple2[] operations = new Tuple2[] { - new Tuple2<String, Operator>(UPDATE_DECREMENT, - ArithmeticValueUpdate.Operator.SUB), - new Tuple2<String, Operator>(UPDATE_DIVIDE, + new Tuple2<>(UPDATE_DECREMENT, + ArithmeticValueUpdate.Operator.SUB), + new Tuple2<>(UPDATE_DIVIDE, ArithmeticValueUpdate.Operator.DIV), - new Tuple2<String, Operator>(UPDATE_INCREMENT, + new Tuple2<>(UPDATE_INCREMENT, ArithmeticValueUpdate.Operator.ADD), - new Tuple2<String, Operator>(UPDATE_MULTIPLY, + new Tuple2<>(UPDATE_MULTIPLY, ArithmeticValueUpdate.Operator.MUL) }; for (Tuple2<String, Operator> operator : operations) { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'match': {", - " 'element': 'person',", - " '" + (String) operator.first + "': 13 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "match": { + "element": "person", + "%s": 13 + } + } + } + } + """.formatted(operator.first)); Map<String, Tuple2<Number, Operator>> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualset"); @@ -475,12 +548,19 @@ public class JsonReaderTestCase { @SuppressWarnings("rawtypes") @Test public void testArrayIndexing() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'match': {", - " 'element': 3,", - " 'assign': 'nalle' }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "match": { + "element": 3, + "assign": "nalle" + } + } + } + } + """); Map<Number, String> matches = new HashMap<>(); FieldUpdate x = doc.getFieldUpdate("actualarray"); @@ -488,7 +568,7 @@ public class JsonReaderTestCase { MapValueUpdate adder = (MapValueUpdate) v; final Number key = ((IntegerFieldValue) adder.getValue()) .getNumber(); - String op = ((StringFieldValue) 
((AssignValueUpdate) adder.getUpdate()) + String op = ((StringFieldValue) adder.getUpdate() .getValue()).getString(); matches.put(key, op); } @@ -515,11 +595,17 @@ public class JsonReaderTestCase { @Test public void testWeightedSet() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'nalle': 2,", - " 'tralle': 7 }}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testset::whee", + "fields": { + "actualset": { + "nalle": 2, + "tralle": 7 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualset")); assertSame(WeightedSet.class, f.getClass()); WeightedSet<?> w = (WeightedSet<?>) f; @@ -530,11 +616,17 @@ public class JsonReaderTestCase { @Test public void testArray() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': [", - " 'nalle',", - " 'tralle' ]}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testarray::whee", + "fields": { + "actualarray": [ + "nalle", + "tralle" + ] + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualarray")); assertSame(Array.class, f.getClass()); Array<?> a = (Array<?>) f; @@ -545,11 +637,17 @@ public class JsonReaderTestCase { @Test public void testMap() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testmap::whee',", - " 'fields': {", - " 'actualmap': {", - " 'nalle': 'kalle',", - " 'tralle': 'skalle' }}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testmap::whee", + "fields": { + "actualmap": { + "nalle": "kalle", + "tralle": "skalle" + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualmap")); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -560,11 +658,23 @@ public class JsonReaderTestCase { @Test public void testOldMap() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testmap::whee',", - " 'fields': {", - " 'actualmap': [", - " { 'key': 'nalle', 'value': 'kalle'},", - " { 'key': 'tralle', 'value': 'skalle'} ]}}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testmap::whee", + "fields": { + "actualmap": [ + { + "key": "nalle", + "value": "kalle" + }, + { + "key": "tralle", + "value": "skalle" + } + ] + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("actualmap")); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -575,9 +685,14 @@ public class JsonReaderTestCase { @Test public void testPositionPositive() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': 'N63.429722;E10.393333' }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": "N63.429722;E10.393333" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -586,9 +701,17 @@ public class JsonReaderTestCase { @Test public void testPositionOld() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'x':10393333,'y':63429722} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "x": 10393333, + "y": 
63429722 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -597,9 +720,17 @@ public class JsonReaderTestCase { @Test public void testGeoPosition() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'lat':63.429722,'lng':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "lat": 63.429722, + "lng": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -608,9 +739,17 @@ public class JsonReaderTestCase { @Test public void testGeoPositionNoAbbreviations() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': {'latitude':63.429722,'longitude':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": { + "latitude": 63.429722, + "longitude": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -619,9 +758,14 @@ public class JsonReaderTestCase { @Test public void testPositionGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': 'N63.429722;E10.393333' }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": "N63.429722;E10.393333" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -631,9 +775,17 @@ public class JsonReaderTestCase { @Test public void testPositionOldGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': {'x':10393333,'y':63429722} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": { + "x": 10393333, + "y": 63429722 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -643,9 +795,17 @@ public class JsonReaderTestCase { @Test public void testGeoPositionGeoPos() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'geopos': {'lat':63.429722,'lng':10.393333} }}")); + Document doc = docFromJson(""" + { + "put": "id:unittest:testsinglepos::bamf", + "fields": { + "geopos": { + "lat": 63.429722, + "lng": 10.393333 + } + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("geopos")); assertSame(Struct.class, f.getClass()); assertEquals(10393333, PositionDataType.getXValue(f).getInteger()); @@ -656,9 +816,14 @@ public class JsonReaderTestCase { @Test public void testPositionNegative() throws IOException { - Document doc = docFromJson(inputJson("{ 'put': 'id:unittest:testsinglepos::bamf',", - " 'fields': {", - " 'singlepos': 'W46.63;S23.55' }}")); + Document doc = docFromJson(""" + { + 
"put": "id:unittest:testsinglepos::bamf", + "fields": { + "singlepos": "W46.63;S23.55" + } + } + """); FieldValue f = doc.getFieldValue(doc.getField("singlepos")); assertSame(Struct.class, f.getClass()); assertEquals(-46630000, PositionDataType.getXValue(f).getInteger()); @@ -682,14 +847,14 @@ public class JsonReaderTestCase { } private String fieldStringFromBase64RawContent(String base64data) throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testraw::whee',", - " 'fields': {", - " 'actualraw': '" + base64data + "' }}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testraw::whee", + "fields": { + "actualraw": "%s" + } + } + """.formatted(base64data)); FieldValue f = doc.getFieldValue(doc.getField("actualraw")); assertSame(Raw.class, f.getClass()); Raw s = (Raw) f; @@ -698,15 +863,16 @@ public class JsonReaderTestCase { @Test public void testMapStringToArrayOfInt() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'bamse': [1, 2, 3] }}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "bamse": [1, 2, 3] + } + } + } + """); FieldValue f = doc.getFieldValue("actualMapStringToArrayOfInt"); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -719,15 +885,19 @@ public class JsonReaderTestCase { @Test public void testOldMapStringToArrayOfInt() throws IOException { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': [", - " { 'key': 'bamse', 'value': [1, 2, 3] } ]}}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - Document doc = put.getDocument(); + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": [ + { + "key": "bamse", + "value": [1, 2, 3] + } + ] + } + } + """); FieldValue f = doc.getFieldValue("actualMapStringToArrayOfInt"); assertSame(MapFieldValue.class, f.getClass()); MapFieldValue<?, ?> m = (MapFieldValue<?, ?>) f; @@ -740,10 +910,16 @@ public class JsonReaderTestCase { @Test public void testAssignToString() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': {", - " 'assign': 'orOther' }}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:smoke::whee", + "fields": { + "something": { + "assign": "orOther" + } + } + } + 
"""); FieldUpdate f = doc.getFieldUpdate("something"); assertEquals(1, f.size()); AssignValueUpdate a = (AssignValueUpdate) f.getValueUpdate(0); @@ -751,11 +927,189 @@ public class JsonReaderTestCase { } @Test + public void testNestedArrayMatch() throws IOException { + DocumentUpdate nested = parseUpdate(""" + { + "update": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": { + "match": { + "element": 1, + "match": { + "element": 2, + "assign": 3 + } + } + } + } + } + """); + + DocumentUpdate equivalent = parseUpdate(""" + { + "update": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": { + "match": { + "match": { + "assign": 3, + "element": 2 + }, + "element": 1 + } + } + } + } + """); + + assertEquals(nested, equivalent); + assertEquals(1, nested.fieldUpdates().size()); + FieldUpdate fu = nested.fieldUpdates().iterator().next(); + assertEquals(1, fu.getValueUpdates().size()); + MapValueUpdate mvu = (MapValueUpdate) fu.getValueUpdate(0); + assertEquals(new IntegerFieldValue(1), mvu.getValue()); + MapValueUpdate nvu = (MapValueUpdate) mvu.getUpdate(); + assertEquals(new IntegerFieldValue(2), nvu.getValue()); + AssignValueUpdate avu = (AssignValueUpdate) nvu.getUpdate(); + assertEquals(new IntegerFieldValue(3), avu.getValue()); + + Document doc = docFromJson(""" + { + "put": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": [ + [1, 2, 3], + [4, 5, 6] + ] + } + } + """); + nested.applyTo(doc); + Document expected = docFromJson(""" + { + "put": "id:unittest:testArrayOfArrayOfInt::whee", + "fields": { + "arrayOfArrayOfInt": [ + [1, 2, 3], + [4, 5, 3] + ] + } + } + """); + assertEquals(expected, doc); + } + + @Test + public void testMatchCannotUpdateNestedFields() { + // Should this work? It doesn't. + assertEquals("Field type Map<string,Array<int>> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "match": { + "element": "bamse", + "match": { + "element": 1, + "assign": 4 + } + } + } + } + } + """)).getMessage()); + } + + @Test + public void testMatchCannotAssignToNestedMap() { + // Unsupported value type for map value assign. + assertEquals("Field type Map<string,Array<int>> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "match": { + "element": "bamse", + "assign": [1, 3, 4] + } + } + } + } + """)).getMessage()); + } + + @Test + public void testMatchCannotAssignToMap() { + // Unsupported value type for map value assign. 
+ assertEquals("Field type Map<string,string> not supported.", + assertThrows(UnsupportedOperationException.class, + () -> parseUpdate(""" + { + "update": "id:unittest:testmap::whee", + "fields": { + "actualmap": { + "match": { + "element": "bamse", + "assign": "bar" + } + } + } + } + """)).getMessage()); + } + + + + @Test + public void testAssignInsideArrayInMap() throws IOException { + Document doc = docFromJson(""" + { + "put": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "bamse": [1, 2, 3] + } + } + }"""); + + assertEquals(2, ((MapFieldValue<StringFieldValue, Array<IntegerFieldValue>>) doc.getFieldValue("actualMapStringToArrayOfInt")) + .get(StringFieldValue.getFactory().create("bamse")).get(1).getInteger()); + + DocumentUpdate update = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt{bamse}[1]": { + "assign": 4 + } + } + } + """); + assertEquals(1, update.fieldPathUpdates().size()); + + update.applyTo(doc); + assertEquals(4, ((MapFieldValue<StringFieldValue, Array<IntegerFieldValue>>) doc.getFieldValue("actualMapStringToArrayOfInt")) + .get(StringFieldValue.getFactory().create("bamse")).get(1).getInteger()); + } + + @Test public void testAssignToArray() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'assign': { 'bamse': [1, 2, 3] }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "assign": { + "bamse": [1, 2, 3] + } + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("actualMapStringToArrayOfInt"); assertEquals(1, f.size()); AssignValueUpdate assign = (AssignValueUpdate) f.getValueUpdate(0); @@ -769,11 +1123,21 @@ public class JsonReaderTestCase { @Test public void testOldAssignToArray() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testMapStringToArrayOfInt::whee',", - " 'fields': {", - " 'actualMapStringToArrayOfInt': {", - " 'assign': [", - " { 'key': 'bamse', 'value': [1, 2, 3] } ]}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testMapStringToArrayOfInt::whee", + "fields": { + "actualMapStringToArrayOfInt": { + "assign": [ + { + "key": "bamse", + "value": [1, 2, 3] + } + ] + } + } + } + """); FieldUpdate f = doc.getFieldUpdate("actualMapStringToArrayOfInt"); assertEquals(1, f.size()); AssignValueUpdate assign = (AssignValueUpdate) f.getValueUpdate(0); @@ -787,12 +1151,19 @@ public class JsonReaderTestCase { @Test public void testAssignToWeightedSet() throws IOException { - DocumentUpdate doc = parseUpdate(inputJson("{ 'update': 'id:unittest:testset::whee',", - " 'fields': {", - " 'actualset': {", - " 'assign': {", - " 'person': 37,", - " 'another person': 41 }}}}")); + DocumentUpdate doc = parseUpdate(""" + { + "update": "id:unittest:testset::whee", + "fields": { + "actualset": { + "assign": { + "person": 37, + "another person": 41 + } + } + } + } + """); FieldUpdate x = doc.getFieldUpdate("actualset"); assertEquals(1, x.size()); AssignValueUpdate assign = (AssignValueUpdate) x.getValueUpdate(0); @@ -805,41 +1176,66 @@ public class JsonReaderTestCase { @Test public void testCompleteFeed() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 
'flag': true,", - " 'nalle': 'bamse' }},", - "{ 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}}},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + }, + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); controlBasicFeed(r); } @Test public void testCompleteFeedWithCreateAndCondition() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' }},", - "{", - " 'condition':'bla',", - " 'update': 'id:unittest:testarray::whee',", - " 'create':true,", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}}},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + }, + { + "condition":"bla", + "update": "id:unittest:testarray::whee", + "create":true, + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); DocumentOperation d = r.next(); Document doc = ((DocumentPut) d).getDocument(); @@ -860,7 +1256,7 @@ public class JsonReaderTestCase { @Test public void testUpdateWithConditionAndCreateInDifferentOrdering() { - int documentsCreated = 106; + int documentsCreated = 106; List<String> parts = Arrays.asList( "\"condition\":\"bla\"", "\"update\": \"id:unittest:testarray::whee\"", @@ -876,8 +1272,7 @@ public class JsonReaderTestCase { } } documents.append("]"); - InputStream rawDoc = new ByteArrayInputStream( - Utf8.toBytes(documents.toString())); + InputStream rawDoc = new ByteArrayInputStream(Utf8.toBytes(documents.toString())); JsonReader r = new JsonReader(types, rawDoc, parserFactory); @@ -886,7 +1281,6 @@ public class JsonReaderTestCase { checkSimpleArrayAdd(update); assertTrue(update.getCreateIfNonExistent()); assertEquals("bla", update.getCondition().getSelection()); - } assertNull(r.next()); @@ -895,13 +1289,18 @@ public class JsonReaderTestCase { @Test public void testCreateIfNonExistentInPut() { - JsonReader r = createReader(inputJson("[{", - " 'create':true,", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse' },", - " 'put': 'id:unittest:smoke::whee'", - "}]")); + JsonReader r = createReader(""" + [ + { + "create":true, + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "put": "id:unittest:smoke::whee" + } + ] + """); var op = r.next(); var put = (DocumentPut) op; assertTrue(put.getCreateIfNonExistent()); @@ -909,23 +1308,32 @@ public class JsonReaderTestCase { @Test public void testCompleteFeedWithIdAfterFields() { - JsonReader r = createReader(inputJson("[", - "{", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' },", - " 'put': 'id:unittest:smoke::whee'", - "},", - "{", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person' ]}},", - " 'update': 'id:unittest:testarray::whee'", - "},", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + 
"fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + }, + "put": "id:unittest:smoke::whee" + }, + { + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + }, + "update": "id:unittest:testarray::whee" + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); controlBasicFeed(r); } @@ -949,10 +1357,21 @@ public class JsonReaderTestCase { @Test public void testCompleteFeedWithEmptyDoc() { - JsonReader r = createReader(inputJson("[", - "{ 'put': 'id:unittest:smoke::whee', 'fields': {} },", - "{ 'update': 'id:unittest:testarray::whee', 'fields': {} },", - "{ 'remove': 'id:unittest:smoke::whee' }]")); + JsonReader r = createReader(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": {} + }, + { + "update": "id:unittest:testarray::whee", + "fields": {} + }, + { + "remove": "id:unittest:smoke::whee" + } + ] + """); DocumentOperation d = r.next(); Document doc = ((DocumentPut) d).getDocument(); @@ -994,45 +1413,53 @@ public class JsonReaderTestCase { @Test public void nonExistingFieldCausesException() throws IOException { - JsonReader r = createReader(inputJson( - "{ 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'smething': 'smoketest',", - " 'nalle': 'bamse' }}")); - DocumentParseInfo parseInfo = r.parseDocument().get(); - DocumentType docType = r.readDocumentType(parseInfo.documentId); - DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); - - try { - new VespaJsonDocumentReader(false).readPut(parseInfo.fieldsBuffer, put); - fail(); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().startsWith("No field 'smething' in the structure of type 'smoke'")); - } + Exception expected = assertThrows(IllegalArgumentException.class, + () -> docFromJson(""" + { + "put": "id:unittest:smoke::whee", + "fields": { + "smething": "smoketest", + "nalle": "bamse" + } + } + """)); + assertTrue(expected.getMessage().startsWith("No field 'smething' in the structure of type 'smoke'")); } @Test public void nonExistingFieldsCanBeIgnoredInPut() throws IOException { - JsonReader r = createReader(inputJson( - "{ ", - " 'put': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'nonexisting1': 'ignored value',", - " 'field1': 'value1',", - " 'nonexisting2': {", - " 'blocks':{", - " 'a':[2.0,3.0],", - " 'b':[4.0,5.0]", - " }", - " },", - " 'field2': 'value2',", - " 'nonexisting3': {", - " 'cells': [{'address': {'x': 'x1'}, 'value': 1.0}]", - " },", - " 'tensor1': {'cells': {'x1': 1.0}},", - " 'nonexisting4': 'ignored value'", - " }", - "}")); + JsonReader r = createReader(""" + { + "put": "id:unittest:smoke::doc1", + "fields": { + "nonexisting1": "ignored value", + "field1": "value1", + "nonexisting2": { + "blocks": { + "a": [2.0, 3.0], + "b": [4.0, 5.0] + } + }, + "field2": "value2", + "nonexisting3": { + "cells": [ + { + "address": { + "x": "x1" + }, + "value": 1.0 + } + ] + }, + "tensor1": { + "cells": { + "x1": 1.0 + } + }, + "nonexisting4": "ignored value" + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); @@ -1049,30 +1476,31 @@ public class JsonReaderTestCase { @Test public void nonExistingFieldsCanBeIgnoredInUpdate() throws IOException{ - JsonReader r = createReader(inputJson( - "{ ", - " 'update': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'nonexisting1': { 'assign': 'ignored value' },", - " 'field1': { 'assign': 'value1' },", - 
" 'nonexisting2': { " + - " 'assign': {", - " 'blocks': {", - " 'a':[2.0,3.0],", - " 'b':[4.0,5.0]", - " }", - " }", - " },", - " 'field2': { 'assign': 'value2' },", - " 'nonexisting3': {", - " 'assign' : {", - " 'cells': [{'address': {'x': 'x1'}, 'value': 1.0}]", - " }", - " },", - " 'tensor1': {'assign': { 'cells': {'x1': 1.0} } },", - " 'nonexisting4': { 'assign': 'ignored value' }", - " }", - "}")); + JsonReader r = createReader(""" + { + "update": "id:unittest:smoke::doc1", + "fields": { + "nonexisting1": { "assign": "ignored value" }, + "field1": { "assign": "value1" }, + "nonexisting2": { + "assign": { + "blocks": { + "a":[2.0,3.0], + "b":[4.0,5.0] + } + } + }, + "field2": { "assign": "value2" }, + "nonexisting3": { + "assign" : { + "cells": [{"address": {"x": "x1"}, "value": 1.0}] + } + }, + "tensor1": {"assign": { "cells": {"x1": 1.0} } }, + "nonexisting4": { "assign": "ignored value" } + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentUpdate update = new DocumentUpdate(docType, parseInfo.documentId); @@ -1089,26 +1517,44 @@ public class JsonReaderTestCase { @Test public void feedWithBasicErrorTest() { - JsonReader r = createReader(inputJson("[", - " { 'put': 'id:test:smoke::0', 'fields': { 'something': 'foo' } },", - " { 'put': 'id:test:smoke::1', 'fields': { 'something': 'foo' } },", - " { 'put': 'id:test:smoke::2', 'fields': { 'something': 'foo' } },", - "]")); - try { - while (r.next() != null) ; - fail(); - } catch (RuntimeException e) { - assertTrue(e.getMessage().contains("JsonParseException")); - } + JsonReader r = createReader(""" + [ + { + "put": "id:test:smoke::0", + "fields": { + "something": "foo" + } + }, + { + "put": "id:test:smoke::1", + "fields": { + "something": "foo" + } + }, + { + "put": "id:test:smoke::2", + "fields": { + "something": "foo" + } + }, + ]"""); // Trailing comma in array ... 
+ assertTrue(assertThrows(RuntimeException.class, + () -> { while (r.next() != null); }) + .getMessage().contains("JsonParseException")); } @Test public void idAsAliasForPutTest() throws IOException{ - JsonReader r = createReader(inputJson("{ 'id': 'id:unittest:smoke::doc1',", - " 'fields': {", - " 'something': 'smoketest',", - " 'flag': true,", - " 'nalle': 'bamse' }}")); + JsonReader r = createReader(""" + { + "id": "id:unittest:smoke::doc1", + "fields": { + "something": "smoketest", + "flag": true, + "nalle": "bamse" + } + } + """); DocumentParseInfo parseInfo = r.parseDocument().get(); DocumentType docType = r.readDocumentType(parseInfo.documentId); DocumentPut put = new DocumentPut(new Document(docType, parseInfo.documentId)); @@ -1138,147 +1584,146 @@ public class JsonReaderTestCase { @Test public void testFeedWithTestAndSetConditionOrderingOne() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " }", - " },", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " },", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\''", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "put": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"", + "fields": { + "something": "smoketest", + "nalle": "bamse" + } + }, + { + "update": "id:unittest:testarray::whee", + "condition": "smoke.something == \\"smoketest\\"", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "remove": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"" + } + ] + """); } @Test public void testFeedWithTestAndSetConditionOrderingTwo() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " }", - " },", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'update': 'id:unittest:testarray::whee',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " },", - " {", - " 'condition': 'smoke.something == \\'smoketest\\'',", - " 'remove': 'id:unittest:smoke::whee'", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "condition": "smoke.something == \\"smoketest\\"", + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + } + }, + { + "condition": "smoke.something == \\"smoketest\\"", + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + }, + { + "condition": "smoke.something == \\"smoketest\\"", + "remove": "id:unittest:smoke::whee" + } + ] + """); } @Test public void testFeedWithTestAndSetConditionOrderingThree() { - testFeedWithTestAndSetCondition( - inputJson("[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " },", - " 'condition': 'smoke.something == \\'smoketest\\''", - " },", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 
'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " },", - " 'condition': 'smoke.something == \\'smoketest\\''", - " },", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'condition': 'smoke.something == \\'smoketest\\''", - " }", - "]" - )); + testFeedWithTestAndSetCondition(""" + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "condition": "smoke.something == \\"smoketest\\"" + }, + { + "update": "id:unittest:testarray::whee", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + }, + "condition": "smoke.something == \\"smoketest\\"" + }, + { + "remove": "id:unittest:smoke::whee", + "condition": "smoke.something == \\"smoketest\\"" + } + ] + """); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldAfterFieldsFieldShouldFailParse() { - final String jsonData = inputJson( - "[", - " {", - " 'put': 'id:unittest:smoke::whee',", - " 'fields': {", - " 'something': 'smoketest',", - " 'nalle': 'bamse'", - " },", - " 'bjarne': 'stroustrup'", - " }", - "]"); + String jsonData = """ + [ + { + "put": "id:unittest:smoke::whee", + "fields": { + "something": "smoketest", + "nalle": "bamse" + }, + "bjarne": "stroustrup" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldBeforeFieldsFieldShouldFailParse() { - final String jsonData = inputJson( - "[", - " {", - " 'update': 'id:unittest:testarray::whee',", - " 'what is this': 'nothing to see here',", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " }", - "]"); - + String jsonData = """ + [ + { + "update": "id:unittest:testarray::whee", + "what is this": "nothing to see here", + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @Test(expected = IllegalArgumentException.class) public void testInvalidFieldWithoutFieldsFieldShouldFailParse() { - String jsonData = inputJson( - "[", - " {", - " 'remove': 'id:unittest:smoke::whee',", - " 'what is love': 'baby, do not hurt me... much'", - " }", - "]"); + String jsonData = """ + [ + { + "remove": "id:unittest:smoke::whee", + "what is love": "baby, do not hurt me... 
much" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @@ -1286,19 +1731,19 @@ public class JsonReaderTestCase { @Test public void testMissingOperation() { try { - String jsonData = inputJson( - "[", - " {", - " 'fields': {", - " 'actualarray': {", - " 'add': [", - " 'person',", - " 'another person'", - " ]", - " }", - " }", - " }", - "]"); + String jsonData = """ + [ + { + "fields": { + "actualarray": { + "add": [ + "person", + "another person" + ] + } + } + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1311,12 +1756,12 @@ public class JsonReaderTestCase { @Test public void testMissingFieldsMapInPut() { try { - String jsonData = inputJson( - "[", - " {", - " 'put': 'id:unittest:smoke::whee'", - " }", - "]"); + String jsonData = """ + [ + { + "put": "id:unittest:smoke::whee" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1329,12 +1774,12 @@ public class JsonReaderTestCase { @Test public void testMissingFieldsMapInUpdate() { try { - String jsonData = inputJson( - "[", - " {", - " 'update': 'id:unittest:smoke::whee'", - " }", - "]"); + String jsonData = """ + [ + { + "update": "id:unittest:smoke::whee" + } + ]"""; new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); fail("Expected exception"); @@ -1345,20 +1790,20 @@ public class JsonReaderTestCase { } @Test - public void testNullValues() { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testnull::doc1',", - " 'fields': {", - " 'intfield': null,", - " 'stringfield': null,", - " 'arrayfield': null,", - " 'weightedsetfield': null,", - " 'mapfield': null,", - " 'tensorfield': null", - " }", - "}")); - DocumentPut put = (DocumentPut) r.readSingleDocument(DocumentOperationType.PUT, - "id:unittest:testnull::doc1").operation(); - Document doc = put.getDocument(); + public void testNullValues() throws IOException { + Document doc = docFromJson(""" + { + "put": "id:unittest:testnull::doc1", + "fields": { + "intfield": null, + "stringfield": null, + "arrayfield": null, + "weightedsetfield": null, + "mapfield": null, + "tensorfield": null + } + } + """); assertFieldValueNull(doc, "intfield"); assertFieldValueNull(doc, "stringfield"); assertFieldValueNull(doc, "arrayfield"); @@ -1368,13 +1813,15 @@ public class JsonReaderTestCase { } @Test(expected=JsonReaderException.class) - public void testNullArrayElement() { - JsonReader r = createReader(inputJson("{ 'put': 'id:unittest:testnull::doc1',", - " 'fields': {", - " 'arrayfield': [ null ]", - " }", - "}")); - r.readSingleDocument(DocumentOperationType.PUT, "id:unittest:testnull::doc1"); + public void testNullArrayElement() throws IOException { + docFromJson(""" + { + "put": "id:unittest:testnull::doc1", + "fields": { + "arrayfield": [ null ] + } + } + """); fail(); } @@ -1429,30 +1876,31 @@ public class JsonReaderTestCase { @Test public void testParsingOfSparseTensorWithCells() { Tensor tensor = assertSparseTensorField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", - createPutWithSparseTensor( - """ - { - "type": "tensor(x{},y{})", - "cells": [ - { "address": { "x": "a", "y": "b" }, "value": 2.0 }, - { "address": { "x": "c", "y": "b" }, "value": 3.0 } - ] - } - """)); + createPutWithSparseTensor(""" + { + "type": "tensor(x{},y{})", + "cells": [ + { "address": { "x": "a", "y": "b" }, "value": 2.0 }, + { "address": { "x": "c", "y": "b" }, "value": 3.0 } + ] + } + """)); assertTrue(tensor instanceof 
MappedTensor); // any functional instance is fine } @Test public void testParsingOfDenseTensorWithCells() { Tensor tensor = assertTensorField("{{x:0,y:0}:2.0,{x:1,y:0}:3.0}}", - createPutWithTensor(inputJson("{", - " 'cells': [", - " { 'address': { 'x': '0', 'y': '0' },", - " 'value': 2.0 },", - " { 'address': { 'x': '1', 'y': '0' },", - " 'value': 3.0 }", - " ]", - "}"), "dense_unbound_tensor"), "dense_unbound_tensor"); + createPutWithTensor(""" + { + "cells": [ + { "address": { "x": 0, "y": 0 }, "value": 2.0 }, + { "address": { "x": 1, "y": 0 }, "value": 3.0 } + ] + } + """, + "dense_unbound_tensor"), + "dense_unbound_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1468,9 +1916,10 @@ public class JsonReaderTestCase { Tensor expected = builder.build(); Tensor tensor = assertTensorField(expected, - createPutWithTensor(inputJson("{", - " 'values': [2.0, 3.0, 4.0, 'inf', 6.0, 7.0]", - "}"), "dense_tensor"), "dense_tensor"); + createPutWithTensor(""" + { + "values": [2.0, 3.0, 4.0, "inf", 6.0, 7.0] + }""", "dense_tensor"), "dense_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1485,9 +1934,10 @@ public class JsonReaderTestCase { builder.cell().label("x", 1).label("y", 2).value(7.0); Tensor expected = builder.build(); Tensor tensor = assertTensorField(expected, - createPutWithTensor(inputJson("{", - " 'values': \"020304050607\"", - "}"), "dense_int8_tensor"), "dense_int8_tensor"); + createPutWithTensor(""" + { + "values": "020304050607" + }""", "dense_int8_tensor"), "dense_int8_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance } @@ -1501,10 +1951,14 @@ public class JsonReaderTestCase { builder.cell().label("x", "bar").label("y", 1).value(6.0); builder.cell().label("x", "bar").label("y", 2).value(7.0); Tensor expected = builder.build(); - String mixedJson = "{\"blocks\":[" + - "{\"address\":{\"x\":\"foo\"},\"values\":\"400040404080\"}," + - "{\"address\":{\"x\":\"bar\"},\"values\":\"40A040C040E0\"}" + - "]}"; + String mixedJson = """ + { + "blocks":[ + {"address":{"x":"foo"},"values":"400040404080"}, + {"address":{"x":"bar"},"values":"40A040C040E0"} + ] + } + """; var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor"); Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); } @@ -1587,10 +2041,14 @@ public class JsonReaderTestCase { builder.cell().label("x", 1).label("y", 2).value(7.0); Tensor expected = builder.build(); - String mixedJson = "{\"blocks\":{" + - "\"0\":[2.0,3.0,4.0]," + - "\"1\":[5.0,6.0,7.0]" + - "}}"; + String mixedJson = """ + { + "blocks":{ + "0":[2.0,3.0,4.0], + "1":[5.0,6.0,7.0] + } + } + """; Tensor tensor = assertTensorField(expected, createPutWithTensor(inputJson(mixedJson), "mixed_tensor"), "mixed_tensor"); assertTrue(tensor instanceof MixedTensor); // this matters for performance @@ -1599,12 +2057,14 @@ public class JsonReaderTestCase { @Test public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() { assertSparseTensorField("{{x:a,y:b}:2.0}", - createPutWithSparseTensor(inputJson("{", - " 'cells': [", - " { 'value': 2.0,", - " 'address': { 'x': 'a', 'y': 'b' } }", - " ]", - "}"))); + createPutWithSparseTensor(""" + { + "cells": [ + { "value": 2.0, + "address": { "x": "a", "y": "b" } } + ] + } + """)); } @Test @@ -1634,23 +2094,27 @@ public class JsonReaderTestCase { @Test public void testAssignUpdateOfTensorWithCells() { assertTensorAssignUpdateSparseField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", - 
createAssignUpdateWithSparseTensor(inputJson("{", - " 'cells': [", - " { 'address': { 'x': 'a', 'y': 'b' },", - " 'value': 2.0 },", - " { 'address': { 'x': 'c', 'y': 'b' },", - " 'value': 3.0 }", - " ]", - "}"))); + createAssignUpdateWithSparseTensor(""" + { + "cells": [ + { "address": { "x": "a", "y": "b" }, + "value": 2.0 }, + { "address": { "x": "c", "y": "b" }, + "value": 3.0 } + ] + } + """)); } @Test public void testAssignUpdateOfTensorDenseShortForm() { assertTensorAssignUpdateDenseField("tensor(x[2],y[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]", - createAssignUpdateWithTensor(inputJson("{", - " 'values': [1,2,3,4,5,6]", - "}"), - "dense_tensor")); + createAssignUpdateWithTensor(""" + { + "values": [1,2,3,4,5,6] + } + """, + "dense_tensor")); } @Test diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index 2854ef8836a..72f0fb977d5 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -74,7 +74,17 @@ public class FileDownloader implements AutoCloseable { } } - Future<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { + /** Returns a future that times out if download takes too long, and return empty on unsuccessful download. */ + public Future<Optional<File>> getFutureFileOrTimeout(FileReferenceDownload fileReferenceDownload) { + return getFutureFile(fileReferenceDownload) + .orTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally(thrown -> { + fileReferenceDownloader.failedDownloading(fileReferenceDownload.fileReference()); + return Optional.empty(); + }); + } + + CompletableFuture<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<File> file = getFileFromFileSystem(fileReference); @@ -135,7 +145,7 @@ public class FileDownloader implements AutoCloseable { } /** Start downloading, the future returned will be complete()d by receiving method in {@link FileReceiver} */ - private synchronized Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + private synchronized CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { return fileReferenceDownloader.startDownload(fileReferenceDownload); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 450801ce530..5ad197e8633 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -67,7 +68,7 @@ public class FileReferenceDownloader { int retryCount = 0; Connection connection = connectionPool.getCurrent(); do { - backoff(retryCount); + backoff(retryCount, end); if (FileDownloader.fileReferenceExists(fileReference, downloadDirectory)) return; @@ -79,24 +80,26 @@ public class FileReferenceDownloader 
{ // exist on just one config server, and which one could be different for each file reference), so // switch to a new connection for every retry connection = connectionPool.switchConnection(connection); - } while (retryCount < 5 || Instant.now().isAfter(end)); + } while (Instant.now().isBefore(end)); fileReferenceDownload.future().completeExceptionally(new RuntimeException("Failed getting " + fileReference)); downloads.remove(fileReference); } - private void backoff(int retryCount) { + private void backoff(int retryCount, Instant end) { if (retryCount > 0) { try { - long sleepTime = Math.min(120_000, (long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis()); - Thread.sleep(sleepTime); + long sleepTime = Math.min(120_000, + Math.min((long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis(), + Duration.between(Instant.now(), end).toMillis())); + if (sleepTime > 0) Thread.sleep(sleepTime); } catch (InterruptedException e) { /* ignored */ } } } - Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<FileReferenceDownload> inProgress = downloads.get(fileReference); if (inProgress.isPresent()) return inProgress.get().future(); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 48afe9bd481..4aac29c5093 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -421,25 +421,6 @@ public class Flags { "Whether to send cloud trial email notifications", "Takes effect immediately"); - public static final UnboundLongFlag MERGING_MAX_MEMORY_USAGE_PER_NODE = defineLongFlag( - "merging-max-memory-usage-per-node", -1, - List.of("vekterli"), "2023-11-03", "2024-03-01", - "Soft limit of the maximum amount of memory that can be used across merge operations on a content node. " + - "Value semantics: < 0: unlimited (legacy behavior), == 0: auto-deduced from node HW and config," + - " > 0: explicit memory usage limit in bytes.", - "Takes effect at redeployment", - INSTANCE_ID); - - public static final UnboundBooleanFlag USE_PER_DOCUMENT_THROTTLED_DELETE_BUCKET = defineFeatureFlag( - "use-per-document-throttled-delete-bucket", false, - List.of("vekterli"), "2023-11-13", "2024-03-01", - "If set, DeleteBucket operations are internally expanded to an individually persistence-" + - "throttled remove per document stored in the bucket. 
This makes the cost model of " + - "executing a DeleteBucket symmetrical with feeding the documents to the bucket in the " + - "first place.", - "Takes effect at redeployment", - INSTANCE_ID); - public static final UnboundBooleanFlag ENABLE_NEW_PAYMENT_METHOD_FLOW = defineFeatureFlag( "enable-new-payment-method-flow", false, List.of("bjorncs"), "2023-11-29", "2024-03-01", diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 1ffb879e57e..dc6a62cc463 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -344,7 +344,9 @@ "public java.lang.String getDestination()", "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", - "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)" + "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", + "public java.util.Map getContextValues()", + "public com.yahoo.language.process.Embedder$Context setContextValues(java.util.Map)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index fa141977d5d..f6fdb86d01f 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -88,6 +88,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; + private Map<String, String> contextValues; public Context(String destination) { this.destination = destination; @@ -138,6 +139,15 @@ public interface Embedder { this.embedderId = embedderId; return this; } + + /** Returns a read-only map of context key-values which can be looked up during conversion. 
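May be null if no context values have been set.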
*/ + public Map<String, String> getContextValues() { return contextValues; } + + public Context setContextValues(Map<String, String> contextValues) { + this.contextValues = contextValues; + return this; + } + } class FailingEmbedder implements Embedder { diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..28f8c4e252f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getDirectIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 07f2aea4ab6..d1a06d8c7ff 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -28,6 +28,7 @@ import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @@ -101,70 +102,67 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } + interface Short2Float { + float convert(short value); + } + + private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, Short2Float converter, 
IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, converter.convert(buffer.get(i))); + } + private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + static Tensor toVespaTensor(OnnxValue onnxValue) { if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } TensorInfo tensorInfo = onnxTensor.getInfo(); - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - + DimensionSizes sizes = DimensionSizes.of(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - long totalSize = sizes.totalSize(); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.FLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); - } - else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); + long totalSizeAsLong = sizes.totalSize(); + if (totalSizeAsLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE"); + } + + int totalSize = (int) totalSizeAsLong; + switch (tensorInfo.type) { + case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize); + case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize); + case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize); + case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize); + case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize); + case INT64 -> 
extractTensor(onnxTensor.getLongBuffer(), builder, totalSize); + case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize); + case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize); + default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } return builder.build(); } - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), e -> toVespaType(e.getValue().getInfo()))); diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..b48051814ab 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,11 +49,11 @@ public class SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 10; + int n = 1000; // Takes around 7s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + Long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java index e6d0c339e5a..2bec9aa6115 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java @@ -158,7 +158,15 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { log.info("Autoscaling " + application + " " + clusterNodes.clusterSpec() + ":" + "\nfrom " + toString(from) + "\nto " + toString(to)); metric.add(ConfigServerMetrics.CLUSTER_AUTOSCALED.baseName(), 1, - metric.createContext(MetricsReporter.dimensions(application, clusterNodes.clusterSpec().id()))); + metric.createContext(dimensions(application, clusterNodes.clusterSpec()))); + } + + private static Map<String, String> dimensions(ApplicationId application, ClusterSpec clusterSpec) { + return Map.of("tenantName", application.tenant().value(), + "applicationId", application.serializedForm().replace(':', '.'), + "app", application.application().value() + "." 
+ application.instance().value(), + "clusterid", clusterSpec.id().value(), + "clustertype", clusterSpec.type().name()); } static String toString(ClusterResources r) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java index b644e5d8a08..b82d1809085 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java @@ -408,7 +408,7 @@ public class MetricsReporter extends NodeRepositoryMaintainer { metric.set(ConfigServerMetrics.NODES_EMPTY_EXCLUSIVE.baseName(), emptyHosts, null); } - public static Map<String, String> dimensions(ApplicationId application, ClusterSpec.Id cluster) { + static Map<String, String> dimensions(ApplicationId application, ClusterSpec.Id cluster) { Map<String, String> dimensions = new HashMap<>(dimensions(application)); dimensions.put("clusterid", cluster.value()); return dimensions; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java index 5d8296d6f9d..bd7f39b1f6e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingOsUpgrader.java @@ -38,7 +38,7 @@ public class DelegatingOsUpgrader extends OsUpgrader { .matching(node -> canUpgradeTo(target.version(), now, node)) .byIncreasingOsVersion() .first(upgradeSlots(target, activeNodes)); - if (nodesToUpgrade.size() == 0) return; + if (nodesToUpgrade.isEmpty()) return; LOG.info("Upgrading " + nodesToUpgrade.size() + " nodes of type " + target.nodeType() + " to OS version " + target.version().toFullString()); nodeRepository.nodes().upgradeOs(NodeListFilter.from(nodesToUpgrade.asList()), Optional.of(target.version())); @@ -49,7 +49,7 @@ public class DelegatingOsUpgrader extends OsUpgrader { NodeList nodesUpgrading = nodeRepository.nodes().list() .nodeType(type) .changingOsVersion(); - if (nodesUpgrading.size() == 0) return; + if (nodesUpgrading.isEmpty()) return; LOG.info("Disabling OS upgrade of all " + type + " nodes"); nodeRepository.nodes().upgradeOs(NodeListFilter.from(nodesUpgrading.asList()), Optional.empty()); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java index f56e75518a3..2f09a0a5a29 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/OsUpgrader.java @@ -44,7 +44,7 @@ public abstract class OsUpgrader { /** Returns the number of upgrade slots available for given target */ final int upgradeSlots(OsVersionTarget target, NodeList candidates) { if (!candidates.stream().allMatch(node -> node.type() == target.nodeType())) { - throw new IllegalArgumentException("All node types must type of OS version target " + target.nodeType()); + throw new IllegalArgumentException("All node types must match type of OS version target " + target.nodeType()); } int max = target.nodeType() == NodeType.host ? 
maxActiveUpgrades.value() : 1; int upgrading = candidates.changingOsVersionTo(target.version()).size(); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java index cb6c7683f23..a5ff7b82551 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringOsUpgrader.java @@ -2,21 +2,32 @@ package com.yahoo.vespa.hosted.provision.os; import com.yahoo.component.Version; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; +import com.yahoo.vespa.hosted.provision.node.ClusterId; import com.yahoo.vespa.hosted.provision.node.filter.NodeListFilter; import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** - * An upgrader that retires and deprovisions hosts on stale OS versions. + * An upgrader that retires and deprovisions hosts on stale OS versions. For hosts containing stateful clusters, this + * upgrader limits node retirement so that at most one group per cluster is affected at a time. * - * Used in clouds where hosts must be re-provisioned to upgrade their OS. + * Used in clouds where the host configuration (e.g. local disk) requires re-provisioning to upgrade OS. * * @author mpolden */ @@ -35,8 +46,8 @@ public class RetiringOsUpgrader extends OsUpgrader { public void upgradeTo(OsVersionTarget target) { NodeList allNodes = nodeRepository.nodes().list(); Instant now = nodeRepository.clock().instant(); - for (var candidate : candidates(now, target, allNodes)) { - deprovision(candidate, target.version(), now); + for (Node host : deprovisionable(now, target, allNodes)) { + deprovision(host, target.version(), now); } } @@ -45,18 +56,46 @@ public class RetiringOsUpgrader extends OsUpgrader { // No action needed in this implementation. 
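The hunk that follows replaces the flat candidate selection with a group-aware one: a host is only deprovisioned if, for every stateful cluster with nodes on it, retiring those nodes would not spread retirement beyond the set of groups that already has retiring nodes. A compact sketch of that invariant (a hypothetical simplification: String stands in for ClusterId, Integer for ClusterSpec.Group):

import java.util.Map;
import java.util.Set;

class GroupRetirementCheck {
    // A host may be deprovisioned only if each stateful cluster on it either has
    // no retiring nodes yet, or is already retiring exactly the groups present on
    // this host. Otherwise a second group of the same cluster would be affected.
    static boolean canDeprovision(Map<String, Set<Integer>> retiringGroupsByCluster,
                                  Map<String, Set<Integer>> groupsOnHost) {
        for (Map.Entry<String, Set<Integer>> e : groupsOnHost.entrySet()) {
            Set<Integer> retiring = retiringGroupsByCluster.get(e.getKey());
            if (retiring != null && !e.getValue().equals(retiring)) return false;
        }
        return true;
    }
}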
} - /** Returns nodes that are candidates for upgrade */ - private NodeList candidates(Instant instant, OsVersionTarget target, NodeList allNodes) { + /** Returns nodes that can be deprovisioned at given instant */ + private List<Node> deprovisionable(Instant instant, OsVersionTarget target, NodeList allNodes) { NodeList nodes = allNodes.state(Node.State.active, Node.State.provisioned).nodeType(target.nodeType()); if (softRebuild) { - // Retire only hosts which do not have a replaceable root disk + // Consider only hosts which do not have a replaceable root disk nodes = nodes.not().replaceableRootDisk(); } - return nodes.not().deprovisioning() - .not().onOsVersion(target.version()) - .matching(node -> canUpgradeTo(target.version(), instant, node)) - .byIncreasingOsVersion() - .first(upgradeSlots(target, nodes.deprovisioning())); + // Retire hosts up to slot limit while ensuring that only one group is retired at a time + NodeList activeNodes = allNodes.state(Node.State.active); + Map<ClusterId, Set<ClusterSpec.Group>> retiringGroupsByCluster = groupsOf(activeNodes.retiring()); + int limit = upgradeSlots(target, nodes.deprovisioning()); + List<Node> result = new ArrayList<>(); + NodeList candidates = nodes.not().deprovisioning() + .not().onOsVersion(target.version()) + .matching(node -> canUpgradeTo(target.version(), instant, node)) + .byIncreasingOsVersion(); + for (Node host : candidates) { + if (result.size() == limit) break; + // For all clusters residing on this host: Determine if deprovisioning the host would imply retiring nodes + // in additional groups beyond those already having retired nodes. If true, defer deprovisioning the host + boolean canDeprovision = true; + Map<ClusterId, Set<ClusterSpec.Group>> groupsOnHost = groupsOf(activeNodes.childrenOf(host)); + for (var clusterAndGroups : groupsOnHost.entrySet()) { + Set<ClusterSpec.Group> groups = clusterAndGroups.getValue(); + Set<ClusterSpec.Group> retiringGroups = retiringGroupsByCluster.get(clusterAndGroups.getKey()); + if (retiringGroups != null && !groups.equals(retiringGroups)) { + canDeprovision = false; + break; + } + } + // Deprovision host and count all cluster groups on the host as being retired + if (canDeprovision) { + result.add(host); + groupsOnHost.forEach((cluster, groups) -> retiringGroupsByCluster.merge(cluster, groups, (oldVal, newVal) -> { + oldVal.addAll(newVal); + return oldVal; + })); + } + } + return Collections.unmodifiableList(result); } /** Upgrade given host by retiring and deprovisioning it */ @@ -68,4 +107,17 @@ public class RetiringOsUpgrader extends OsUpgrader { nodeRepository.nodes().upgradeOs(NodeListFilter.from(host), Optional.of(target)); } + /** Returns the stateful groups present on given nodes, grouped by their cluster ID */ + private static Map<ClusterId, Set<ClusterSpec.Group>> groupsOf(NodeList nodes) { + return nodes.stream() + .filter(node -> node.allocation().isPresent() && + node.allocation().get().membership().cluster().isStateful() && + node.allocation().get().membership().cluster().group().isPresent()) + .collect(Collectors.groupingBy(node -> new ClusterId(node.allocation().get().owner(), + node.allocation().get().membership().cluster().id()), + HashMap::new, + Collectors.mapping(n -> n.allocation().get().membership().cluster().group().get(), + Collectors.toCollection(HashSet::new)))); + } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java index 89853896104..5c788731386 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java @@ -1,10 +1,14 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.restapi; -import com.yahoo.config.provision.IntRange; +import com.yahoo.component.Version; +import com.yahoo.component.Vtag; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.ClusterResources; +import com.yahoo.config.provision.IntRange; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; +import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Application; @@ -15,6 +19,7 @@ import com.yahoo.vespa.hosted.provision.autoscale.Limits; import com.yahoo.vespa.hosted.provision.autoscale.Load; import java.net.URI; +import java.util.Comparator; import java.util.List; /** @@ -40,6 +45,13 @@ public class ApplicationSerializer { URI applicationUri) { object.setString("url", applicationUri.toString()); object.setString("id", application.id().toFullString()); + Version version = applicationNodes.stream() + .map(node -> node.status().vespaVersion() + .orElse(node.allocation().get().membership().cluster().vespaVersion())) + .min(Comparator.naturalOrder()) + .orElse(Vtag.currentVersion); + object.setString("version", version.toFullString()); + object.setString("cloudAccount", applicationNodes.stream().findFirst().map(Node::cloudAccount).orElse(CloudAccount.empty).value()); clustersToSlime(application, applicationNodes, nodeRepository, object.setObject("clusters")); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java index dcbac44a37f..3f2d7112224 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java @@ -253,6 +253,63 @@ public class OsVersionsTest { } @Test + public void upgrade_by_retiring_is_limited_by_group_membership() { + var versions = new OsVersions(tester.nodeRepository(), Cloud.builder().dynamicProvisioning(true).build(), + Optional.ofNullable(tester.hostProvisioner())); + int hostCount = 7; + int app1GroupCount = 2; + setMaxActiveUpgrades(hostCount); + ApplicationId app1 = ApplicationId.from("t1", "a1", "i1"); + ApplicationId app2 = ApplicationId.from("t2", "a2", "i2"); + provisionInfraApplication(hostCount, NodeType.host); + deployApplication(app1, app1GroupCount); + deployApplication(app2); + Supplier<NodeList> hosts = () -> tester.nodeRepository().nodes().list() + .nodeType(NodeType.host) + .not().state(Node.State.deprovisioned); + + // All hosts are on initial version + var version0 = Version.fromString("8.0"); + versions.setTarget(NodeType.host, version0, false); + setCurrentVersion(hosts.get().asList(), version0); + + // New version is triggered + var version1 = Version.fromString("8.5"); + versions.setTarget(NodeType.host, version1, false); + versions.resumeUpgradeOf(NodeType.host, true); + { + // At most one node per group is 
retired + NodeList allNodes = tester.nodeRepository().nodes().list().not().state(Node.State.deprovisioned); + assertEquals(hostCount - 1, allNodes.nodeType(NodeType.host).deprovisioning().size()); + assertEquals(1, allNodes.owner(app1).retiring().group(0).size()); + assertEquals(0, allNodes.owner(app1).retiring().group(1).size()); + assertEquals(2, allNodes.owner(app2).retiring().size()); + + // Hosts complete reprovisioning + NodeList emptyHosts = allNodes.deprovisioning().nodeType(NodeType.host) + .matching(h -> allNodes.childrenOf(h).isEmpty()); + completeReprovisionOf(emptyHosts.asList(), NodeType.host); + replaceNodes(app1, app1GroupCount); + replaceNodes(app2); + completeReprovisionOf(hosts.get().deprovisioning().asList(), NodeType.host); + } + { + // Last host/group is retired + versions.resumeUpgradeOf(NodeType.host, true); + NodeList allNodes = tester.nodeRepository().nodes().list().not().state(Node.State.deprovisioned); + assertEquals(1, allNodes.nodeType(NodeType.host).deprovisioning().size()); + assertEquals(0, allNodes.owner(app1).retiring().group(0).size()); + assertEquals(1, allNodes.owner(app1).retiring().group(1).size()); + assertEquals(0, allNodes.owner(app2).retiring().size()); + replaceNodes(app1, app1GroupCount); + completeReprovisionOf(hosts.get().deprovisioning().asList(), NodeType.host); + } + NodeList allHosts = hosts.get(); + assertEquals(0, allHosts.deprovisioning().size()); + assertEquals(allHosts.size(), allHosts.onOsVersion(version1).size()); + } + + @Test public void upgrade_by_rebuilding() { var versions = new OsVersions(tester.nodeRepository(), Cloud.defaultCloud(), Optional.ofNullable(tester.hostProvisioner())); setMaxActiveUpgrades(1); @@ -547,24 +604,32 @@ public class OsVersionsTest { } private void deployApplication(ApplicationId application) { + deployApplication(application, 1); + } + + private void deployApplication(ApplicationId application, int groups) { ClusterSpec contentSpec = ClusterSpec.request(ClusterSpec.Type.content, ClusterSpec.Id.from("content1")).vespaVersion("7").build(); - List<HostSpec> hostSpecs = tester.prepare(application, contentSpec, 2, 1, new NodeResources(4, 8, 100, 0.3)); + List<HostSpec> hostSpecs = tester.prepare(application, contentSpec, 2, groups, new NodeResources(4, 8, 100, 0.3)); tester.activate(application, hostSpecs); } - private void replaceNodes(ApplicationId application) { + private void replaceNodes(ApplicationId application, int groups) { // Deploy to retire nodes - deployApplication(application); + deployApplication(application, groups); NodeList retired = tester.nodeRepository().nodes().list().owner(application).retired(); assertFalse("At least one node is retired", retired.isEmpty()); tester.nodeRepository().nodes().setRemovable(retired, false); // Redeploy to deactivate removable nodes and allocate new ones - deployApplication(application); + deployApplication(application, groups); tester.nodeRepository().nodes().list(Node.State.inactive).owner(application) .forEach(node -> tester.nodeRepository().nodes().removeRecursively(node, true)); } + private void replaceNodes(ApplicationId application) { + replaceNodes(application, 1); + } + private NodeList deprovisioningChildrenOf(Node parent) { return tester.nodeRepository().nodes().list() .childrenOf(parent) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json index 7b2cf1dc8e4..d5bbc648ed8 100644 --- 
a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application1.json @@ -1,6 +1,8 @@ { "url" : "http://localhost:8080/nodes/v2/applications/tenant1.application1.instance1", "id" : "tenant1.application1.instance1", + "cloudAccount": "aws:111222333444", + "version": "5.104.142", "clusters" : { "id1" : { "type": "container", diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json index 10173089f75..abd76f9be96 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/application2.json @@ -1,6 +1,8 @@ { "url": "http://localhost:8080/nodes/v2/applications/tenant2.application2.instance2", "id": "tenant2.application2.instance2", + "cloudAccount": "aws:111222333444", + "version": "6.42.0", "clusters": { "id2": { "type": "content", diff --git a/parent/pom.xml b/parent/pom.xml index 82bf09ea92b..c10c74c0cb1 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -317,7 +317,7 @@ --> <groupId>org.openrewrite.maven</groupId> <artifactId>rewrite-maven-plugin</artifactId> - <version>5.20.0</version> + <version>5.21.0</version> <configuration> <activeRecipes> <recipe>org.openrewrite.java.testing.junit5.JUnit5BestPractices</recipe> @@ -327,7 +327,7 @@ <dependency> <groupId>org.openrewrite.recipe</groupId> <artifactId>rewrite-testing-frameworks</artifactId> - <version>2.3.0</version> + <version>2.3.1</version> </dependency> </dependencies> </plugin> @@ -1172,7 +1172,7 @@ See pluginManagement of rewrite-maven-plugin for more details --> <groupId>org.openrewrite.recipe</groupId> <artifactId>rewrite-recipe-bom</artifactId> - <version>2.6.2</version> + <version>2.6.3</version> <type>pom</type> <scope>import</scope> </dependency> diff --git a/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp b/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp index e136e491f05..4aefa10f5f2 100644 --- a/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp +++ b/searchcore/src/tests/proton/documentmetastore/lid_allocator/lid_allocator_test.cpp @@ -180,9 +180,10 @@ TEST_F(LidAllocatorTest, whitelist_blueprint_can_maximize_relative_estimate) activate_lids({ 1, 2, 3, 4 }, true); // the number of hits are overestimated based on the number of // documents that could be active (100 in this test fixture) - EXPECT_EQ(make_whitelist_blueprint(1000)->estimate(), 0.1); - EXPECT_EQ(make_whitelist_blueprint(200)->estimate(), 0.5); - EXPECT_EQ(make_whitelist_blueprint(5)->estimate(), 1.0); + // NOTE: optimize must be called in order to calculate the relative estimate + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(1000))->estimate(), 0.1); + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(200))->estimate(), 0.5); + EXPECT_EQ(Blueprint::optimize(make_whitelist_blueprint(5))->estimate(), 1.0); } class LidAllocatorPerformanceTest : public LidAllocatorTest, diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp index f81b47583b9..30ba7d320f7 100644 --- 
a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp @@ -627,7 +627,7 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI bool allAttributes, OnWriteDoneType onWriteDone) { for (const auto &wc : _writeContexts) { - if (wc.use_two_phase_put()) { + if (allAttributes && wc.use_two_phase_put()) { assert(wc.getFields().size() == 1); wc.consider_build_field_paths(doc); auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc, doc); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp index 784ce649c5f..98c5daa1415 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_phase_limiter.cpp @@ -102,9 +102,8 @@ do_limit(AttributeLimiter &limiter_factory, SearchIterator::UP search, double ma return search; } -// When hitrate is below 1% limiting the query is often far more expensive than not. -// TODO This limit should probably be a lot higher. -constexpr double MIN_HIT_RATE_LIMIT = 0.01; +// When hitrate is below 0.2% limiting the query is often far more expensive than not. +constexpr double MIN_HIT_RATE_LIMIT = 0.002; } // namespace proton::matching::<unnamed> diff --git a/searchcore/src/vespa/searchcore/proton/matching/query.cpp b/searchcore/src/vespa/searchcore/proton/matching/query.cpp index a93e8fbbddc..5ade0a44b8a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/query.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/query.cpp @@ -200,8 +200,7 @@ Query::reserveHandles(const IRequestContext & requestContext, ISearchContext &co void Query::optimize(bool sort_by_cost) { - (void) sort_by_cost; - _blueprint = Blueprint::optimize(std::move(_blueprint), sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); LOG(debug, "optimized blueprint:\n%s\n", _blueprint->asString().c_str()); } @@ -223,7 +222,7 @@ Query::handle_global_filter(const IRequestContext & requestContext, uint32_t doc } // optimized order may change after accounting for global filter: trace.addEvent(5, "Optimize query execution plan to account for global filter"); - _blueprint = Blueprint::optimize(std::move(_blueprint), sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); LOG(debug, "blueprint after handle_global_filter:\n%s\n", _blueprint->asString().c_str()); // strictness may change if optimized order changed: fetchPostings(ExecuteInfo::create(true, 1.0, requestContext.getDoom(), requestContext.thread_bundle())); diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 5628db99171..7bec70c00cb 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -190,6 +190,7 @@ vespa_define_module( src/tests/postinglistbm src/tests/predicate src/tests/query + src/tests/query/streaming src/tests/queryeval src/tests/queryeval/blueprint src/tests/queryeval/dot_product diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 591f0eb8b37..97aa42f79c9 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -177,7 +177,7 @@ List<ReferenceNode> featureList() : ReferenceNode exp; } { - ( ( exp = feature() { ret.add(exp); } )+ <EOF> ) + ( ( 
exp = feature() { ret.add(exp); } )* <EOF> ) { return ret; } } diff --git a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp index f612bdda87f..6e479e1d9db 100644 --- a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp +++ b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp @@ -430,7 +430,8 @@ BitVectorTest::test(BasicType bt, CollectionType ct, const vespalib::string &pre sc = getSearch<VectorType>(tv, filter); checkSearch(v, std::move(sc), 2, 1022, 205, !filter, true); const auto* dww = v->as_docid_with_weight_posting_store(); - if (dww != nullptr) { + if ((dww != nullptr) && (bt == BasicType::STRING)) { + // This way of doing lookup is only supported by string attributes. auto lres = dww->lookup(getSearchStr<VectorType>(), dww->get_dictionary_snapshot()); using DWSI = search::queryeval::DocidWithWeightSearchIterator; TermFieldMatchData md; diff --git a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp index a94ca9d890e..87b771af8e6 100644 --- a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp @@ -26,6 +26,7 @@ using LookupKey = IDirectPostingStore::LookupKey; struct IntegerKey : public LookupKey { int64_t _value; IntegerKey(int64_t value_in) : _value(value_in) {} + IntegerKey(const vespalib::string&) : _value() { abort(); } vespalib::stringref asString() const override { abort(); } bool asInteger(int64_t& value) const override { value = _value; return true; } }; @@ -33,6 +34,7 @@ struct IntegerKey : public LookupKey { struct StringKey : public LookupKey { vespalib::string _value; StringKey(int64_t value_in) : _value(std::to_string(value_in)) {} + StringKey(const vespalib::string& value_in) : _value(value_in) {} vespalib::stringref asString() const override { return _value; } bool asInteger(int64_t&) const override { abort(); } }; @@ -78,6 +80,10 @@ populate_attribute(AttributeType& attr, const std::vector<DataType>& values) for (auto docid : range(300, 128)) { attr.update(docid, values[3]); } + if (values.size() > 4) { + attr.update(40, values[4]); + attr.update(41, values[5]); + } attr.commit(true); } @@ -93,7 +99,7 @@ make_attribute(CollectionType col_type, BasicType type, bool field_is_filter) auto attr = test::AttributeBuilder(field_name, cfg).docs(num_docs).get(); if (type == BasicType::STRING) { populate_attribute<StringAttribute, vespalib::string>(dynamic_cast<StringAttribute&>(*attr), - {"1", "3", "100", "300"}); + {"1", "3", "100", "300", "foo", "Foo"}); } else { populate_attribute<IntegerAttribute, int64_t>(dynamic_cast<IntegerAttribute&>(*attr), {1, 3, 100, 300}); @@ -223,15 +229,16 @@ public: tfmd.tagAsNotNeeded(); } } - template <typename BlueprintType> - void add_term_helper(BlueprintType& b, int64_t term_value) { + template <typename BlueprintType, typename TermType> + void add_term_helper(BlueprintType& b, TermType term_value) { if (integer_type) { b.addTerm(IntegerKey(term_value), 1, estimate); } else { b.addTerm(StringKey(term_value), 1, estimate); } } - void add_term(int64_t term_value) { + template <typename TermType> + void add_term(TermType term_value) { if (single_type) { if (in_operator) { add_term_helper(dynamic_cast<SingleInBlueprintType&>(*blueprint), term_value); @@ -246,8 +253,18 @@ 
public: } } } - std::unique_ptr<SearchIterator> create_leaf_search() const { - return blueprint->createLeafSearch(tfmda, true); + void add_terms(const std::vector<int64_t>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + void add_terms(const std::vector<vespalib::string>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + std::unique_ptr<SearchIterator> create_leaf_search(bool strict = true) const { + return blueprint->createLeafSearch(tfmda, strict); } vespalib::string resolve_iterator_with_unpack() const { if (in_operator) { @@ -262,7 +279,7 @@ expect_hits(const Docids& exp_docids, SearchIterator& itr) { SimpleResult exp(exp_docids); SimpleResult act; - act.search(itr); + act.search(itr, doc_id_limit); EXPECT_EQ(exp, act); } @@ -294,8 +311,7 @@ INSTANTIATE_TEST_SUITE_P(DefaultInstantiation, TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_none_filter_field) { setup(false, true); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(resolve_iterator_with_unpack())); expect_hits({10, 30, 31}, *itr); @@ -307,8 +323,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_used_instead_of_btree_iterators_ if (!in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -322,8 +337,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ if (in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid_and_weights)); expect_hits(concat({10}, range(100, 128)), *itr); @@ -332,10 +346,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field) { setup(true, true); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -347,8 +358,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) { setup(true, true); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -359,8 +369,7 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_none)); expect_hits({10, 30, 31}, *itr); @@ -369,10 +378,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -384,8 +390,7 @@ TEST_P(DirectMultiTermBlueprintTest, 
bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -393,4 +398,41 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ expect_hits(concat(range(100, 128), range(300, 128)), *itr); } +TEST_P(DirectMultiTermBlueprintTest, hash_filter_used_for_non_strict_iterator_with_10_or_more_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_non_strict_iterator_with_9_or_less_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid)); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, hash_filter_with_string_folding_used_for_non_strict_iterator) +{ + setup(true, true); + if (!single_type || integer_type) { + return; + } + // "foo" matches documents with "foo" (40) and "Foo" (41). + add_terms({"foo", "3", "3", "3", "3", "3", "3", "3", "3", "3"}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({30, 31, 40, 41}, *itr); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp b/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp index c1e12580559..2ddd211bd12 100644 --- a/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp +++ b/searchlib/src/tests/attribute/direct_posting_store/direct_posting_store_test.cpp @@ -142,7 +142,11 @@ TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidPostingStore_int TEST(DirectPostingStoreApiTest, attributes_support_IDocidWithWeightPostingStore_interface) { expect_docid_with_weight_posting_store(BasicType::INT64, CollectionType::WSET, true); + expect_docid_with_weight_posting_store(BasicType::INT32, CollectionType::WSET, true); expect_docid_with_weight_posting_store(BasicType::STRING, CollectionType::WSET, true); + expect_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, true); + expect_docid_with_weight_posting_store(BasicType::INT32, CollectionType::ARRAY, true); + expect_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, true); } TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidWithWeightPostingStore_interface) { @@ -150,13 +154,11 @@ TEST(DirectPostingStoreApiTest, attributes_do_not_support_IDocidWithWeightPostin expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, false); expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::WSET, false); expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::SINGLE, true); - expect_not_docid_with_weight_posting_store(BasicType::INT64, CollectionType::ARRAY, true); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::SINGLE, false); 
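// (Inferred from the expectations above and below; the third argument is presumably the fast-search flag.)
// After this change, fast-search WSET and ARRAY collections of INT32/INT64/STRING expose
// IDocidWithWeightPostingStore, while non-fast-search attributes, SINGLE collections,
// and DOUBLE element types still do not.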
expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, false); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::WSET, false); expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::SINGLE, true); - expect_not_docid_with_weight_posting_store(BasicType::STRING, CollectionType::ARRAY, true); - expect_not_docid_with_weight_posting_store(BasicType::INT32, CollectionType::WSET, true); + expect_not_docid_with_weight_posting_store(BasicType::DOUBLE, CollectionType::ARRAY, true); expect_not_docid_with_weight_posting_store(BasicType::DOUBLE, CollectionType::WSET, true); } diff --git a/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp b/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp index ecc03ac54c5..3b346601245 100644 --- a/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attribute_searchable_adapter_test.cpp @@ -473,35 +473,6 @@ TEST("require that attribute dot product can produce no hits") { } } -TEST("require that direct attribute iterators work") { - for (int i = 0; i <= 0x3; ++i) { - bool fast_search = ((i & 0x1) != 0); - bool strict = ((i & 0x2) != 0); - MyAttributeManager attribute_manager = make_weighted_string_attribute_manager(fast_search); - SimpleStringTerm empty_node("notfoo", "", 0, Weight(1)); - Result empty_result = do_search(attribute_manager, empty_node, strict); - EXPECT_EQUAL(0u, empty_result.hits.size()); - SimpleStringTerm node("foo", "", 0, Weight(1)); - Result result = do_search(attribute_manager, node, strict); - if (fast_search) { - EXPECT_EQUAL(3u, result.est_hits); - EXPECT_TRUE(result.has_minmax); - EXPECT_EQUAL(100, result.min_weight); - EXPECT_EQUAL(1000, result.max_weight); - EXPECT_TRUE(result.iterator_dump.find("DocidWithWeightSearchIterator") != vespalib::string::npos); - } else { - EXPECT_EQUAL(num_docs, result.est_hits); - EXPECT_FALSE(result.has_minmax); - EXPECT_TRUE(result.iterator_dump.find("DocidWithWeightSearchIterator") == vespalib::string::npos); - } - ASSERT_EQUAL(3u, result.hits.size()); - EXPECT_FALSE(result.est_empty); - EXPECT_EQUAL(20u, result.hits[0].docid); - EXPECT_EQUAL(40u, result.hits[1].docid); - EXPECT_EQUAL(50u, result.hits[2].docid); - } -} - TEST("require that single weighted set turns filter on filter fields") { bool fast_search = true; bool strict = true; diff --git a/searchlib/src/tests/nearsearch/nearsearch_test.cpp b/searchlib/src/tests/nearsearch/nearsearch_test.cpp index 3751fc93cea..95701e59444 100644 --- a/searchlib/src/tests/nearsearch/nearsearch_test.cpp +++ b/searchlib/src/tests/nearsearch/nearsearch_test.cpp @@ -215,7 +215,7 @@ bool Test::testNearSearch(MyQuery &query, uint32_t matchId) { LOG(info, "testNearSearch(%d)", matchId); - search::queryeval::IntermediateBlueprint *near_b = 0; + search::queryeval::IntermediateBlueprint *near_b = nullptr; if (query.isOrdered()) { near_b = new search::queryeval::ONearBlueprint(query.getWindow()); } else { @@ -228,9 +228,10 @@ Test::testNearSearch(MyQuery &query, uint32_t matchId) layout.allocTermField(fieldId); near_b->addChild(query.getTerm(i).make_blueprint(fieldId, i)); } - search::fef::MatchData::UP md(layout.createMatchData()); - + bp->setDocIdLimit(1000); + bp = search::queryeval::Blueprint::optimize_and_sort(std::move(bp), true, true); bp->fetchPostings(search::queryeval::ExecuteInfo::TRUE); + search::fef::MatchData::UP md(layout.createMatchData()); 
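// The reordered setup above reflects the presumed new blueprint lifecycle:
// set a docid limit, run optimize_and_sort() over the tree, fetch postings,
// and only then create the MatchData layout and the search iterator below.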
search::queryeval::SearchIterator::UP near = bp->createSearch(*md, true); near->initFullRange(); bool foundMatch = false; diff --git a/searchlib/src/tests/query/streaming/CMakeLists.txt b/searchlib/src/tests/query/streaming/CMakeLists.txt new file mode 100644 index 00000000000..257ad24854d --- /dev/null +++ b/searchlib/src/tests/query/streaming/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +vespa_add_executable(searchlib_query_streaming_hit_iterator_test_app TEST + SOURCES + hit_iterator_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_hit_iterator_test_app COMMAND searchlib_query_streaming_hit_iterator_test_app) + +vespa_add_executable(searchlib_query_streaming_hit_iterator_pack_test_app TEST + SOURCES + hit_iterator_pack_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_hit_iterator_pack_test_app COMMAND searchlib_query_streaming_hit_iterator_pack_test_app) + +vespa_add_executable(searchlib_query_streaming_same_element_query_node_test_app TEST + SOURCES + same_element_query_node_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_same_element_query_node_test_app COMMAND searchlib_query_streaming_same_element_query_node_test_app) + +vespa_add_executable(searchlib_query_streaming_phrase_query_node_test_app TEST + SOURCES + phrase_query_node_test.cpp + DEPENDS + searchlib + GTest::gtest +) +vespa_add_test(NAME searchlib_query_streaming_phrase_query_node_test_app COMMAND searchlib_query_streaming_phrase_query_node_test_app) diff --git a/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp b/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp new file mode 100644 index 00000000000..7d7d8307920 --- /dev/null +++ b/searchlib/src/tests/query/streaming/hit_iterator_pack_test.cpp @@ -0,0 +1,44 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
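// Summary of the behavior exercised below, inferred from the assertions:
// HitIteratorPack holds one HitIterator per query term and advances them in
// lockstep. seek_to_matching_field_element() positions every iterator on the
// lowest (field_id, element_id) pair present in all terms' hit lists and
// returns false once no shared field element remains.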
+ +#include <vespa/searchlib/query/streaming/hit_iterator_pack.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::streaming::HitIterator; +using search::streaming::HitIteratorPack; +using search::streaming::QueryNodeList; +using search::streaming::QueryTerm; +using search::streaming::QueryNodeResultBase; + +using FieldElement = HitIterator::FieldElement; + +TEST(HitIteratorPackTest, seek_to_matching_field_element) +{ + QueryNodeList qnl; + auto qt = std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTerm::Type::WORD); + qt->add(11, 0, 10, 0); + qt->add(11, 0, 10, 5); + qt->add(11, 1, 12, 0); + qt->add(11, 1, 12, 0); + qt->add(12, 1, 13, 0); + qt->add(12, 1, 13, 0); + qnl.emplace_back(std::move(qt)); + qt = std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "8", "", QueryTerm::Type::WORD); + qt->add(2, 0, 4, 0); + qt->add(11, 0, 10, 0); + qt->add(12, 1, 13, 0); + qt->add(12, 2, 14, 0); + qnl.emplace_back(std::move(qt)); + HitIteratorPack itr_pack(qnl); + EXPECT_TRUE(itr_pack.all_valid()); + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(11, 0), itr_pack.get_field_element_ref()); + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(11, 0), itr_pack.get_field_element_ref()); + ++itr_pack.get_field_element_ref().second; + EXPECT_TRUE(itr_pack.seek_to_matching_field_element()); + EXPECT_EQ(FieldElement(12, 1), itr_pack.get_field_element_ref()); + ++itr_pack.get_field_element_ref().second; + EXPECT_FALSE(itr_pack.seek_to_matching_field_element()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/hit_iterator_test.cpp b/searchlib/src/tests/query/streaming/hit_iterator_test.cpp new file mode 100644 index 00000000000..a9588ea3d6c --- /dev/null +++ b/searchlib/src/tests/query/streaming/hit_iterator_test.cpp @@ -0,0 +1,122 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
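// Behavior under test, inferred from the checks below: HitIterator walks a
// HitList ordered by (field_id, element_id, position). seek_to_field_element()
// stops at the first hit at or after the requested (field_id, element_id),
// while step_in_field_element() and seek_in_field_element() advance within
// the current field element and return false when crossing into the next one.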
+ +#include <vespa/searchlib/query/streaming/hit_iterator.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::streaming::Hit; +using search::streaming::HitList; +using search::streaming::HitIterator; + +using FieldElement = HitIterator::FieldElement; + +namespace { + +HitList +make_hit_list() +{ + HitList hl; + hl.emplace_back(11, 0, 10, 0); + hl.emplace_back(11, 0, 10, 5); + hl.emplace_back(11, 1, 12, 0); + hl.emplace_back(11, 1, 12, 7); + hl.emplace_back(12, 1, 13, 0); + hl.emplace_back(12, 1, 13, 9); + return hl; +} + +void +check_seek_to_field_elem(HitIterator& it, const FieldElement& field_element, const Hit* exp_ptr, const vespalib::string& label) +{ + SCOPED_TRACE(label); + EXPECT_TRUE(it.seek_to_field_element(field_element)); + EXPECT_TRUE(it.valid()); + EXPECT_EQ(exp_ptr, &*it); +} + +void +check_seek_to_field_elem_failure(HitIterator& it, const FieldElement& field_element, const vespalib::string& label) +{ + SCOPED_TRACE(label); + EXPECT_FALSE(it.seek_to_field_element(field_element)); + EXPECT_FALSE(it.valid()); +} + +void +check_step_in_field_element(HitIterator& it, FieldElement& field_element, bool exp_success, const Hit* exp_ptr, const vespalib::string& label) +{ + SCOPED_TRACE(label); + EXPECT_EQ(exp_success, it.step_in_field_element(field_element)); + if (exp_ptr) { + EXPECT_TRUE(it.valid()); + EXPECT_EQ(it.get_field_element(), field_element); + EXPECT_EQ(exp_ptr, &*it); + } else { + EXPECT_FALSE(it.valid()); + } +} + +void +check_seek_in_field_element(HitIterator& it, uint32_t position, FieldElement& field_element, bool exp_success, const Hit* exp_ptr, const vespalib::string& label) +{ + SCOPED_TRACE(label); + EXPECT_EQ(exp_success, it.seek_in_field_element(position, field_element)); + if (exp_ptr) { + EXPECT_TRUE(it.valid()); + EXPECT_EQ(it.get_field_element(), field_element); + EXPECT_EQ(exp_ptr, &*it); + } else { + EXPECT_FALSE(it.valid()); + } +} + +} + +TEST(HitIteratorTest, seek_to_field_element) +{ + auto hl = make_hit_list(); + HitIterator it(hl); + EXPECT_TRUE(it.valid()); + EXPECT_EQ(&hl[0], &*it); + check_seek_to_field_elem(it, FieldElement(0, 0), &hl[0], "(0, 0)"); + check_seek_to_field_elem(it, FieldElement(11, 0), &hl[0], "(11, 0)"); + check_seek_to_field_elem(it, FieldElement(11, 1), &hl[2], "(11, 1)"); + check_seek_to_field_elem(it, FieldElement(11, 2), &hl[4], "(11, 2)"); + check_seek_to_field_elem(it, FieldElement(12, 0), &hl[4], "(12, 0)"); + check_seek_to_field_elem(it, FieldElement(12, 1), &hl[4], "(12, 1)"); + check_seek_to_field_elem_failure(it, FieldElement(12, 2), "(12, 2)"); + check_seek_to_field_elem_failure(it, FieldElement(13, 0), "(13, 0)"); +} + +TEST(HitIteratorTest, step_in_field_element) +{ + auto hl = make_hit_list(); + HitIterator it(hl); + auto field_element = it.get_field_element(); + check_step_in_field_element(it, field_element, true, &hl[1], "1"); + check_step_in_field_element(it, field_element, false, &hl[2], "2"); + check_step_in_field_element(it, field_element, true, &hl[3], "3"); + check_step_in_field_element(it, field_element, false, &hl[4], "4"); + check_step_in_field_element(it, field_element, true, &hl[5], "5"); + check_step_in_field_element(it, field_element, false, nullptr, "end"); +} + +TEST(HitIteratorTest, seek_in_field_element) +{ + auto hl = make_hit_list(); + HitIterator it(hl); + auto field_element = it.get_field_element(); + check_seek_in_field_element(it, 0, field_element, true, &hl[0], "0a"); + check_seek_in_field_element(it, 2, field_element, true, &hl[1], "2"); + check_seek_in_field_element(it, 5, 
field_element, true, &hl[1], "5"); + check_seek_in_field_element(it, 6, field_element, false, &hl[2], "6"); + check_seek_in_field_element(it, 0, field_element, true, &hl[2], "0b"); + check_seek_in_field_element(it, 1, field_element, true, &hl[3], "1"); + check_seek_in_field_element(it, 7, field_element, true, &hl[3], "7"); + check_seek_in_field_element(it, 8, field_element, false, &hl[4], "8"); + check_seek_in_field_element(it, 0, field_element, true, &hl[4], "0c"); + check_seek_in_field_element(it, 3, field_element, true, &hl[5], "3"); + check_seek_in_field_element(it, 9, field_element, true, &hl[5], "9"); + check_seek_in_field_element(it, 10, field_element, false, nullptr, "end"); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp b/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp new file mode 100644 index 00000000000..5caae8d6e97 --- /dev/null +++ b/searchlib/src/tests/query/streaming/phrase_query_node_test.cpp @@ -0,0 +1,95 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/query/streaming/phrase_query_node.h> +#include <vespa/searchlib/query/streaming/queryterm.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::query::QueryBuilder; +using search::query::Node; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::query::Weight; +using search::streaming::HitList; +using search::streaming::PhraseQueryNode; +using search::streaming::Query; +using search::streaming::QueryTerm; +using search::streaming::QueryNodeRefList; +using search::streaming::QueryNodeResultFactory; +using search::streaming::QueryTermList; + +TEST(PhraseQueryNodeTest, test_phrase_evaluate) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + builder.addPhrase(3, "", 0, Weight(0)); + { + builder.addStringTerm("a", "", 0, Weight(0)); + builder.addStringTerm("b", "", 0, Weight(0)); + builder.addStringTerm("c", "", 0, Weight(0)); + } + Node::UP node = builder.build(); + vespalib::string stackDump = StackDumpCreator::create(*node); + QueryNodeResultFactory empty; + Query q(empty, stackDump); + QueryNodeRefList phrases; + q.getPhrases(phrases); + QueryTermList terms; + q.getLeaves(terms); + for (QueryTerm * qt : terms) { + qt->resizeFieldId(1); + } + + // field 0 + terms[0]->add(0, 0, 1, 0); + terms[1]->add(0, 0, 1, 1); + terms[2]->add(0, 0, 1, 2); + terms[0]->add(0, 0, 1, 7); + terms[1]->add(0, 1, 1, 8); + terms[2]->add(0, 0, 1, 9); + // field 1 + terms[0]->add(1, 0, 1, 4); + terms[1]->add(1, 0, 1, 5); + terms[2]->add(1, 0, 1, 6); + // field 2 (not complete match) + terms[0]->add(2, 0, 1, 1); + terms[1]->add(2, 0, 1, 2); + terms[2]->add(2, 0, 1, 4); + // field 3 + terms[0]->add(3, 0, 1, 0); + terms[1]->add(3, 0, 1, 1); + terms[2]->add(3, 0, 1, 2); + // field 4 (not complete match) + terms[0]->add(4, 0, 1, 1); + terms[1]->add(4, 0, 1, 2); + // field 5 (not complete match) + terms[0]->add(5, 0, 1, 2); + terms[1]->add(5, 0, 1, 1); + terms[2]->add(5, 0, 1, 0); + HitList hits; + auto * p = static_cast<PhraseQueryNode *>(phrases[0]); + p->evaluateHits(hits); + ASSERT_EQ(3u, hits.size()); + EXPECT_EQ(0u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(2u, hits[0].position()); + EXPECT_EQ(1u, hits[1].field_id()); + EXPECT_EQ(0u, 
hits[1].element_id()); + EXPECT_EQ(6u, hits[1].position()); + EXPECT_EQ(3u, hits[2].field_id()); + EXPECT_EQ(0u, hits[2].element_id()); + EXPECT_EQ(2u, hits[2].position()); + ASSERT_EQ(4u, p->getFieldInfoSize()); + EXPECT_EQ(0u, p->getFieldInfo(0).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(0).getHitCount()); + EXPECT_EQ(1u, p->getFieldInfo(1).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(1).getHitCount()); + EXPECT_EQ(0u, p->getFieldInfo(2).getHitOffset()); // invalid, but will never be used + EXPECT_EQ(0u, p->getFieldInfo(2).getHitCount()); + EXPECT_EQ(2u, p->getFieldInfo(3).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(3).getHitCount()); + EXPECT_TRUE(p->evaluate()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp b/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp new file mode 100644 index 00000000000..ece6dc551b2 --- /dev/null +++ b/searchlib/src/tests/query/streaming/same_element_query_node_test.cpp @@ -0,0 +1,135 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchlib/query/streaming/same_element_query_node.h> +#include <vespa/searchlib/query/streaming/queryterm.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::query::QueryBuilder; +using search::query::Node; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::query::Weight; +using search::streaming::HitList; +using search::streaming::Query; +using search::streaming::QueryNode; +using search::streaming::QueryNodeResultFactory; +using search::streaming::QueryTerm; +using search::streaming::QueryTermList; +using search::streaming::SameElementQueryNode; + +namespace { + +class AllowRewrite : public QueryNodeResultFactory +{ +public: + explicit AllowRewrite(vespalib::stringref index) noexcept : _allowedIndex(index) {} + bool allow_float_terms_rewrite(vespalib::stringref index) const noexcept override { return index == _allowedIndex; } +private: + vespalib::string _allowedIndex; +}; + +} + +TEST(SameElementQueryNodeTest, a_unhandled_sameElement_stack) +{ + const char * stack = "\022\002\026xyz_abcdefghij_xyzxyzxQ\001\vxxxxxx_name\034xxxxxx_xxxx_xxxxxxx_xxxxxxxxE\002\005delta\b<0.00393"; + vespalib::stringref stackDump(stack); + EXPECT_EQ(85u, stackDump.size()); + AllowRewrite empty(""); + const Query q(empty, stackDump); + EXPECT_TRUE(q.valid()); + const QueryNode & root = q.getRoot(); + auto sameElement = dynamic_cast<const SameElementQueryNode *>(&root); + EXPECT_TRUE(sameElement != nullptr); + EXPECT_EQ(2u, sameElement->size()); + EXPECT_EQ("xyz_abcdefghij_xyzxyzx", sameElement->getIndex()); + auto term0 = dynamic_cast<const QueryTerm *>((*sameElement)[0].get()); + EXPECT_TRUE(term0 != nullptr); + auto term1 = dynamic_cast<const QueryTerm *>((*sameElement)[1].get()); + EXPECT_TRUE(term1 != nullptr); +} + +namespace { + void verifyQueryTermNode(const vespalib::string & index, const QueryNode *node) { + EXPECT_TRUE(dynamic_cast<const QueryTerm *>(node) != nullptr); + EXPECT_EQ(index, node->getIndex()); + } +} + +TEST(SameElementQueryNodeTest, test_same_element_evaluate) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + builder.addSameElement(3, "field", 0, Weight(0)); + { + builder.addStringTerm("a", "f1", 0, Weight(0)); + 
builder.addStringTerm("b", "f2", 1, Weight(0)); + builder.addStringTerm("c", "f3", 2, Weight(0)); + } + Node::UP node = builder.build(); + vespalib::string stackDump = StackDumpCreator::create(*node); + QueryNodeResultFactory empty; + Query q(empty, stackDump); + auto * sameElem = dynamic_cast<SameElementQueryNode *>(&q.getRoot()); + EXPECT_TRUE(sameElem != nullptr); + EXPECT_EQ("field", sameElem->getIndex()); + EXPECT_EQ(3u, sameElem->size()); + verifyQueryTermNode("field.f1", (*sameElem)[0].get()); + verifyQueryTermNode("field.f2", (*sameElem)[1].get()); + verifyQueryTermNode("field.f3", (*sameElem)[2].get()); + + QueryTermList terms; + q.getLeaves(terms); + EXPECT_EQ(3u, terms.size()); + for (QueryTerm * qt : terms) { + qt->resizeFieldId(3); + } + + // field 0 + terms[0]->add(0, 0, 10, 1); + terms[0]->add(0, 1, 20, 2); + terms[0]->add(0, 2, 30, 3); + terms[0]->add(0, 3, 40, 4); + terms[0]->add(0, 4, 50, 5); + terms[0]->add(0, 5, 60, 6); + + terms[1]->add(1, 0, 70, 7); + terms[1]->add(1, 1, 80, 8); + terms[1]->add(1, 2, 90, 9); + terms[1]->add(1, 4, 100, 10); + terms[1]->add(1, 5, 110, 11); + terms[1]->add(1, 6, 120, 12); + + terms[2]->add(2, 0, 130, 13); + terms[2]->add(2, 2, 140, 14); + terms[2]->add(2, 4, 150, 15); + terms[2]->add(2, 5, 160, 16); + terms[2]->add(2, 6, 170, 17); + HitList hits; + sameElem->evaluateHits(hits); + EXPECT_EQ(4u, hits.size()); + EXPECT_EQ(2u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(130, hits[0].element_weight()); + EXPECT_EQ(0u, hits[0].position()); + + EXPECT_EQ(2u, hits[1].field_id()); + EXPECT_EQ(2u, hits[1].element_id()); + EXPECT_EQ(140, hits[1].element_weight()); + EXPECT_EQ(0u, hits[1].position()); + + EXPECT_EQ(2u, hits[2].field_id()); + EXPECT_EQ(4u, hits[2].element_id()); + EXPECT_EQ(150, hits[2].element_weight()); + EXPECT_EQ(0u, hits[2].position()); + + EXPECT_EQ(2u, hits[3].field_id()); + EXPECT_EQ(5u, hits[3].element_id()); + EXPECT_EQ(160, hits[3].element_weight()); + EXPECT_EQ(0u, hits[3].position()); + EXPECT_TRUE(sameElem->evaluate()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index fe6149e6fba..d2be1d453a2 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -4,6 +4,7 @@ #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> +#include <vespa/searchlib/query/streaming/phrase_query_node.h> #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/searchlib/query/streaming/wand_term.h> @@ -23,10 +24,11 @@ using TermType = QueryTerm::Type; using search::fef::SimpleTermData; using search::fef::MatchData; -void assertHit(const Hit & h, size_t expWordpos, size_t expContext, int32_t weight) { - EXPECT_EQ(h.wordpos(), expWordpos); - EXPECT_EQ(h.context(), expContext); - EXPECT_EQ(h.weight(), weight); +void assertHit(const Hit & h, uint32_t exp_field_id, uint32_t exp_element_id, int32_t exp_element_weight, size_t exp_position) { + EXPECT_EQ(h.field_id(), exp_field_id); + EXPECT_EQ(h.element_id(), exp_element_id); + EXPECT_EQ(h.element_weight(), exp_element_weight); + EXPECT_EQ(h.position(), exp_position); } @@ -427,87 +429,19 @@ TEST(StreamingQueryTest, test_get_query_parts) } } -TEST(StreamingQueryTest, test_phrase_evaluate) +TEST(StreamingQueryTest, 
test_hit) { - QueryBuilder<SimpleQueryNodeTypes> builder; - builder.addPhrase(3, "", 0, Weight(0)); - { - builder.addStringTerm("a", "", 0, Weight(0)); - builder.addStringTerm("b", "", 0, Weight(0)); - builder.addStringTerm("c", "", 0, Weight(0)); - } - Node::UP node = builder.build(); - vespalib::string stackDump = StackDumpCreator::create(*node); - QueryNodeResultFactory empty; - Query q(empty, stackDump); - QueryNodeRefList phrases; - q.getPhrases(phrases); - QueryTermList terms; - q.getLeaves(terms); - for (QueryTerm * qt : terms) { - qt->resizeFieldId(1); - } + // field id + assertHit(Hit( 1, 0, 1, 0), 1, 0, 1, 0); + assertHit(Hit(255, 0, 1, 0), 255, 0, 1, 0); + assertHit(Hit(256, 0, 1, 0), 256, 0, 1, 0); - // field 0 - terms[0]->add(0, 0, 0, 1); - terms[1]->add(1, 0, 0, 1); - terms[2]->add(2, 0, 0, 1); - terms[0]->add(7, 0, 0, 1); - terms[1]->add(8, 0, 1, 1); - terms[2]->add(9, 0, 0, 1); - // field 1 - terms[0]->add(4, 1, 0, 1); - terms[1]->add(5, 1, 0, 1); - terms[2]->add(6, 1, 0, 1); - // field 2 (not complete match) - terms[0]->add(1, 2, 0, 1); - terms[1]->add(2, 2, 0, 1); - terms[2]->add(4, 2, 0, 1); - // field 3 - terms[0]->add(0, 3, 0, 1); - terms[1]->add(1, 3, 0, 1); - terms[2]->add(2, 3, 0, 1); - // field 4 (not complete match) - terms[0]->add(1, 4, 0, 1); - terms[1]->add(2, 4, 0, 1); - // field 5 (not complete match) - terms[0]->add(2, 5, 0, 1); - terms[1]->add(1, 5, 0, 1); - terms[2]->add(0, 5, 0, 1); - HitList hits; - auto * p = static_cast<PhraseQueryNode *>(phrases[0]); - p->evaluateHits(hits); - ASSERT_EQ(3u, hits.size()); - EXPECT_EQ(hits[0].wordpos(), 2u); - EXPECT_EQ(hits[0].context(), 0u); - EXPECT_EQ(hits[1].wordpos(), 6u); - EXPECT_EQ(hits[1].context(), 1u); - EXPECT_EQ(hits[2].wordpos(), 2u); - EXPECT_EQ(hits[2].context(), 3u); - ASSERT_EQ(4u, p->getFieldInfoSize()); - EXPECT_EQ(p->getFieldInfo(0).getHitOffset(), 0u); - EXPECT_EQ(p->getFieldInfo(0).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitOffset(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(2).getHitOffset(), 0u); // invalid, but will never be used - EXPECT_EQ(p->getFieldInfo(2).getHitCount(), 0u); - EXPECT_EQ(p->getFieldInfo(3).getHitOffset(), 2u); - EXPECT_EQ(p->getFieldInfo(3).getHitCount(), 1u); - EXPECT_TRUE(p->evaluate()); -} + // positions + assertHit(Hit(0, 0, 0, 0), 0, 0, 0, 0); + assertHit(Hit(0, 0, 1, 256), 0, 0, 1, 256); + assertHit(Hit(0, 0, -1, 16777215), 0, 0, -1, 16777215); + assertHit(Hit(0, 0, 1, 16777216), 0, 0, 1, 16777216); -TEST(StreamingQueryTest, test_hit) -{ - // positions (0 - (2^24-1)) - assertHit(Hit(0, 0, 0, 0), 0, 0, 0); - assertHit(Hit(256, 0, 0, 1), 256, 0, 1); - assertHit(Hit(16777215, 0, 0, -1), 16777215, 0, -1); - assertHit(Hit(16777216, 0, 0, 1), 0, 1, 1); // overflow - - // contexts (0 - 255) - assertHit(Hit(0, 1, 0, 1), 0, 1, 1); - assertHit(Hit(0, 255, 0, 1), 0, 255, 1); - assertHit(Hit(0, 256, 0, 1), 0, 0, 1); // overflow } void assertInt8Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { @@ -769,105 +703,6 @@ TEST(StreamingQueryTest, require_that_we_do_not_break_the_stack_on_bad_query) EXPECT_FALSE(term.isValid()); } -TEST(StreamingQueryTest, a_unhandled_sameElement_stack) -{ - const char * stack = "\022\002\026xyz_abcdefghij_xyzxyzxQ\001\vxxxxxx_name\034xxxxxx_xxxx_xxxxxxx_xxxxxxxxE\002\005delta\b<0.00393"; - vespalib::stringref stackDump(stack); - EXPECT_EQ(85u, stackDump.size()); - AllowRewrite empty(""); - const Query q(empty, stackDump); - EXPECT_TRUE(q.valid()); - const 
QueryNode & root = q.getRoot(); - auto sameElement = dynamic_cast<const SameElementQueryNode *>(&root); - EXPECT_TRUE(sameElement != nullptr); - EXPECT_EQ(2u, sameElement->size()); - EXPECT_EQ("xyz_abcdefghij_xyzxyzx", sameElement->getIndex()); - auto term0 = dynamic_cast<const QueryTerm *>((*sameElement)[0].get()); - EXPECT_TRUE(term0 != nullptr); - auto term1 = dynamic_cast<const QueryTerm *>((*sameElement)[1].get()); - EXPECT_TRUE(term1 != nullptr); -} - -namespace { - void verifyQueryTermNode(const vespalib::string & index, const QueryNode *node) { - EXPECT_TRUE(dynamic_cast<const QueryTerm *>(node) != nullptr); - EXPECT_EQ(index, node->getIndex()); - } -} - -TEST(StreamingQueryTest, test_same_element_evaluate) -{ - QueryBuilder<SimpleQueryNodeTypes> builder; - builder.addSameElement(3, "field", 0, Weight(0)); - { - builder.addStringTerm("a", "f1", 0, Weight(0)); - builder.addStringTerm("b", "f2", 1, Weight(0)); - builder.addStringTerm("c", "f3", 2, Weight(0)); - } - Node::UP node = builder.build(); - vespalib::string stackDump = StackDumpCreator::create(*node); - QueryNodeResultFactory empty; - Query q(empty, stackDump); - auto * sameElem = dynamic_cast<SameElementQueryNode *>(&q.getRoot()); - EXPECT_TRUE(sameElem != nullptr); - EXPECT_EQ("field", sameElem->getIndex()); - EXPECT_EQ(3u, sameElem->size()); - verifyQueryTermNode("field.f1", (*sameElem)[0].get()); - verifyQueryTermNode("field.f2", (*sameElem)[1].get()); - verifyQueryTermNode("field.f3", (*sameElem)[2].get()); - - QueryTermList terms; - q.getLeaves(terms); - EXPECT_EQ(3u, terms.size()); - for (QueryTerm * qt : terms) { - qt->resizeFieldId(3); - } - - // field 0 - terms[0]->add(1, 0, 0, 10); - terms[0]->add(2, 0, 1, 20); - terms[0]->add(3, 0, 2, 30); - terms[0]->add(4, 0, 3, 40); - terms[0]->add(5, 0, 4, 50); - terms[0]->add(6, 0, 5, 60); - - terms[1]->add(7, 1, 0, 70); - terms[1]->add(8, 1, 1, 80); - terms[1]->add(9, 1, 2, 90); - terms[1]->add(10, 1, 4, 100); - terms[1]->add(11, 1, 5, 110); - terms[1]->add(12, 1, 6, 120); - - terms[2]->add(13, 2, 0, 130); - terms[2]->add(14, 2, 2, 140); - terms[2]->add(15, 2, 4, 150); - terms[2]->add(16, 2, 5, 160); - terms[2]->add(17, 2, 6, 170); - HitList hits; - sameElem->evaluateHits(hits); - EXPECT_EQ(4u, hits.size()); - EXPECT_EQ(0u, hits[0].wordpos()); - EXPECT_EQ(2u, hits[0].context()); - EXPECT_EQ(0u, hits[0].elemId()); - EXPECT_EQ(130, hits[0].weight()); - - EXPECT_EQ(0u, hits[1].wordpos()); - EXPECT_EQ(2u, hits[1].context()); - EXPECT_EQ(2u, hits[1].elemId()); - EXPECT_EQ(140, hits[1].weight()); - - EXPECT_EQ(0u, hits[2].wordpos()); - EXPECT_EQ(2u, hits[2].context()); - EXPECT_EQ(4u, hits[2].elemId()); - EXPECT_EQ(150, hits[2].weight()); - - EXPECT_EQ(0u, hits[3].wordpos()); - EXPECT_EQ(2u, hits[3].context()); - EXPECT_EQ(5u, hits[3].elemId()); - EXPECT_EQ(160, hits[3].weight()); - EXPECT_TRUE(sameElem->evaluate()); -} - TEST(StreamingQueryTest, test_nearest_neighbor_query_node) { QueryBuilder<SimpleQueryNodeTypes> builder; @@ -917,8 +752,8 @@ TEST(StreamingQueryTest, test_in_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q = *term.get_terms().front(); - q.add(0, 11, 0, 1); - q.add(0, 12, 0, 1); + q.add(11, 0, 1, 0); + q.add(12, 0, 1, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -944,11 +779,11 @@ TEST(StreamingQueryTest, dot_product_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, -13); - 
q0.add(0, 12, 0, -17); + q0.add(11, 0, -13, 0); + q0.add(12, 0, -17, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 4); - q1.add(0, 12, 0, 9); + q1.add(11, 0, 4, 0); + q1.add(12, 0, 9, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -989,11 +824,11 @@ check_wand_term(double limit, const vespalib::string& label) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 17); - q0.add(0, 12, 0, 13); + q0.add(11, 0, 17, 0); + q0.add(12, 0, 13, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 9); - q1.add(0, 12, 0, 4); + q1.add(11, 0, 9, 0); + q1.add(12, 0, 4, 0); EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -1043,11 +878,11 @@ TEST(StreamingQueryTest, weighted_set_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 10); - q0.add(0, 12, 0, 10); + q0.add(11, 0, 10, 0); + q0.add(12, 0, 10, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 10); - q1.add(0, 12, 0, 10); + q1.add(11, 0, 10, 0); + q1.add(12, 0, 10, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp index bbd2744119a..90452f1d12b 100644 --- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp @@ -23,12 +23,15 @@ class MyOr : public IntermediateBlueprint { private: public: - double calculate_cost() const final { - return OrFlow::cost_of(get_children()); - } double calculate_relative_estimate() const final { return OrFlow::estimate_of(get_children()); } + double calculate_cost() const final { + return OrFlow::cost_of(get_children(), false); + } + double calculate_strict_cost() const final { + return OrFlow::cost_of(get_children(), true); + } HitEstimate combine(const std::vector<HitEstimate> &data) const override { return max(data); } @@ -37,7 +40,7 @@ public: return mixChildrenFields(); } - void sort(Children &children, bool) const override { + void sort(Children &children, bool, bool) const override { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -446,7 +449,7 @@ TEST_F("testChildAndNotCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -486,7 +489,7 @@ TEST_F("testChildAndCollapsing", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -525,7 +528,10 @@ TEST_F("testChildOrCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + // we sort non-strict here since the default costs of 1/est for + // non-strict/strict leaf iterators makes the order of iterators + // under a strict OR irrelevant. 
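// (Presumed semantics of the new API: optimize_and_sort(bp, strict, sort_by_cost)
// optimizes the tree and then orders the children; the middle argument selects
// whether strict or non-strict cost estimates drive the ordering.)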
+ unsorted = Blueprint::optimize_and_sort(std::move(unsorted), false, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -569,7 +575,7 @@ TEST_F("testChildSorting", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize(std::move(unsorted), true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -650,12 +656,13 @@ getExpectedBlueprint() " estimate: HitEstimate {\n" " empty: false\n" " estHits: 9\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 2\n" " allow_termwise_eval: false\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " children: std::vector {\n" @@ -671,12 +678,13 @@ getExpectedBlueprint() " estimate: HitEstimate {\n" " empty: false\n" " estHits: 9\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 1\n" " allow_termwise_eval: true\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " }\n" @@ -702,12 +710,13 @@ getExpectedSlimeBlueprint() { " '[type]': 'HitEstimate'," " empty: false," " estHits: 9," - " relative_estimate: 0.5," " cost_tier: 1," " tree_size: 2," " allow_termwise_eval: false" " }," - " cost: 1.0," + " relative_estimate: 0.0," + " cost: 0.0," + " strict_cost: 0.0," " sourceId: 4294967295," " docid_limit: 0," " children: {" @@ -728,12 +737,13 @@ getExpectedSlimeBlueprint() { " '[type]': 'HitEstimate'," " empty: false," " estHits: 9," - " relative_estimate: 0.5," " cost_tier: 1," " tree_size: 1," " allow_termwise_eval: true" " }," - " cost: 1.0," + " relative_estimate: 0.0," + " cost: 0.0," + " strict_cost: 0.0," " sourceId: 4294967295," " docid_limit: 0" " }" @@ -786,9 +796,9 @@ TEST("requireThatDocIdLimitInjectionWorks") } TEST("Control object sizes") { - EXPECT_EQUAL(40u, sizeof(Blueprint::State)); - EXPECT_EQUAL(40u, sizeof(Blueprint)); - EXPECT_EQUAL(80u, sizeof(LeafBlueprint)); + EXPECT_EQUAL(32u, sizeof(Blueprint::State)); + EXPECT_EQUAL(56u, sizeof(Blueprint)); + EXPECT_EQUAL(96u, sizeof(LeafBlueprint)); } TEST_MAIN() { diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 856ac2391f8..2cf523b508b 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -77,7 +77,7 @@ void check_sort_order(IntermediateBlueprint &self, BlueprintVector children, std unordered.push_back(child.get()); } // TODO: sort by cost (requires both setDocIdLimit and optimize to be called) - self.sort(children, false); + self.sort(children, true, false); for (size_t i = 0; i < children.size(); ++i) { EXPECT_EQUAL(children[i].get(), unordered[order[i]]); } @@ -130,8 +130,8 @@ TEST("test AndNot Blueprint") { } template <typename BP> -void optimize(std::unique_ptr<BP> &ref) { - auto optimized = Blueprint::optimize(std::move(ref), true); +void optimize(std::unique_ptr<BP> &ref, bool strict) { + auto optimized = Blueprint::optimize_and_sort(std::move(ref), strict, true); ref.reset(dynamic_cast<BP*>(optimized.get())); ASSERT_TRUE(ref); optimized.release(); @@ -144,7 +144,7 @@ TEST("test And propagates updated histestimate") { bp->addChild(ap(MyLeafSpec(200).create<RememberExecuteInfo>()->setSourceId(2))); 
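// Strict AND over three leaves (the two visible here have estimates 200 and
// 2000, docid limit 5000); the loop below presumably verifies that
// fetchPostings() handed each RememberExecuteInfo child an ExecuteInfo
// carrying the updated hit estimate.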
bp->addChild(ap(MyLeafSpec(2000).create<RememberExecuteInfo>()->setSourceId(2))); bp->setDocIdLimit(5000); - optimize(bp); + optimize(bp, true); bp->fetchPostings(ExecuteInfo::TRUE); EXPECT_EQUAL(3u, bp->childCnt()); for (uint32_t i = 0; i < bp->childCnt(); i++) { @@ -164,7 +164,10 @@ TEST("test Or propagates updated histestimate") { bp->addChild(ap(MyLeafSpec(800).create<RememberExecuteInfo>()->setSourceId(2))); bp->addChild(ap(MyLeafSpec(20).create<RememberExecuteInfo>()->setSourceId(2))); bp->setDocIdLimit(5000); - optimize(bp); + // sort OR as non-strict to get expected order. With strict OR, + // the order would be irrelevant since we use the relative + // estimate as strict_cost for leaves. + optimize(bp, false); bp->fetchPostings(ExecuteInfo::TRUE); EXPECT_EQUAL(4u, bp->childCnt()); for (uint32_t i = 0; i < bp->childCnt(); i++) { @@ -519,16 +522,18 @@ void compare(const Blueprint &bp1, const Blueprint &bp2, bool expect_eq) { auto cmp_hook = [expect_eq](const auto &path, const auto &a, const auto &b) { if (!path.empty() && std::holds_alternative<vespalib::stringref>(path.back())) { vespalib::stringref field = std::get<vespalib::stringref>(path.back()); - if (field == "cost") { + // ignore these fields to enable comparing optimized with unoptimized trees + if (field == "relative_estimate" || field == "cost" || field == "strict_cost") { + auto check_value = [&](double value){ + if (value > 0.0 && value < 1e-6) { + fprintf(stderr, " small value at %s: %g\n", path_to_str(path).c_str(), + value); + } + }; + check_value(a.asDouble()); + check_value(b.asDouble()); return true; } - if (field == "relative_estimate") { - double a_val = a.asDouble(); - double b_val = b.asDouble(); - if (a_val != 0.0 && b_val != 0.0 && vespalib::approx_equal(a_val, b_val)) { - return true; - } - } } if (expect_eq) { fprintf(stderr, " mismatch at %s: %s vs %s\n", path_to_str(path).c_str(), @@ -548,13 +553,13 @@ void compare(const Blueprint &bp1, const Blueprint &bp2, bool expect_eq) { } void -optimize_and_compare(Blueprint::UP top, Blueprint::UP expect, bool sort_by_cost = true) { +optimize_and_compare(Blueprint::UP top, Blueprint::UP expect, bool strict = true, bool sort_by_cost = true) { top->setDocIdLimit(1000); expect->setDocIdLimit(1000); TEST_DO(compare(*top, *expect, false)); - top = Blueprint::optimize(std::move(top), sort_by_cost); + top = Blueprint::optimize_and_sort(std::move(top), strict, sort_by_cost); TEST_DO(compare(*top, *expect, true)); - expect = Blueprint::optimize(std::move(expect), sort_by_cost); + expect = Blueprint::optimize_and_sort(std::move(expect), strict, sort_by_cost); TEST_DO(compare(*expect, *top, true)); } @@ -614,7 +619,8 @@ TEST_F("test SourceBlender below OR partial optimization", SourceBlenderTestFixt expect->addChild(addLeafsWithSourceId(std::make_unique<SourceBlenderBlueprint>(f.selector_2), {{10, 1}, {20, 2}})); addLeafs(*expect, {3, 2, 1}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST_F("test OR replaced by source blender after full optimization", SourceBlenderTestFixture) { @@ -626,7 +632,8 @@ TEST_F("test OR replaced by source blender after full optimization", SourceBlend expect->addChild(addLeafsWithSourceId(2, std::make_unique<OrBlueprint>(), {{2000, 2}, {200, 2}, {20, 2}})); expect->addChild(addLeafsWithSourceId(1, std::make_unique<OrBlueprint>(), {{1000, 1}, {100, 1}, {10, 1}})); - 
optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST_F("test SourceBlender below AND_NOT optimization", SourceBlenderTestFixture) { @@ -681,11 +688,11 @@ TEST("test empty root node optimization and safeness") { //------------------------------------------------------------------------- auto expect_up = std::make_unique<EmptyBlueprint>(); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top1), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top2), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top3), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top4), true)->asString()); - EXPECT_EQUAL(expect_up->asString(), Blueprint::optimize(std::move(top5), true)->asString()); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top1), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top2), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top3), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top4), true, true), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top5), true, true), true); } TEST("and with one empty child is optimized away") { @@ -693,11 +700,11 @@ TEST("and with one empty child is optimized away") { Blueprint::UP top = ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). addChild(addLeafs(std::make_unique<AndBlueprint>(), {{0, true}, 10, 20}))); - top = Blueprint::optimize(std::move(top), true); + top = Blueprint::optimize_and_sort(std::move(top), true, true); Blueprint::UP expect_up(ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). 
addChild(std::make_unique<EmptyBlueprint>()))); - EXPECT_EQUAL(expect_up->asString(), top->asString()); + compare(*expect_up, *top, true); } struct make { @@ -731,7 +738,7 @@ struct make { return std::move(*this); } make &&leaf(uint32_t estimate) && { - std::unique_ptr<LeafBlueprint> bp(MyLeafSpec(estimate).create()); + std::unique_ptr<MyLeaf> bp(MyLeafSpec(estimate).create()); if (cost_tag != invalid_cost) { bp->set_cost(cost_tag); cost_tag = invalid_cost; @@ -763,7 +770,8 @@ TEST("AND AND collapsing") { TEST("OR OR collapsing") { Blueprint::UP top = make::OR().leafs({1,3,5}).add(make::OR().leafs({2,4})); Blueprint::UP expect = make::OR().leafs({5,4,3,2,1}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST("AND_NOT AND_NOT collapsing") { @@ -841,7 +849,8 @@ TEST("test single child optimization") { TEST("test empty OR child optimization") { Blueprint::UP top = addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 20, {0, true}, 10, {0, true}, 0, 30, {0, true}}); Blueprint::UP expect = addLeafs(std::make_unique<OrBlueprint>(), {30, 20, 10, 0}); - optimize_and_compare(std::move(top), std::move(expect)); + // NOTE: use non-strict cost based sorting for expected order + optimize_and_compare(std::move(top), std::move(expect), false); } TEST("test empty AND_NOT child optimization") { @@ -868,10 +877,10 @@ TEST("require that replaced blueprints retain source id") { addChild(ap(MyLeafSpec(30).create()->setSourceId(55))))); Blueprint::UP expect2_up(ap(MyLeafSpec(30).create()->setSourceId(42))); //------------------------------------------------------------------------- - top1_up = Blueprint::optimize(std::move(top1_up), true); - top2_up = Blueprint::optimize(std::move(top2_up), true); - EXPECT_EQUAL(expect1_up->asString(), top1_up->asString()); - EXPECT_EQUAL(expect2_up->asString(), top2_up->asString()); + top1_up = Blueprint::optimize_and_sort(std::move(top1_up), true, true); + top2_up = Blueprint::optimize_and_sort(std::move(top2_up), true, true); + compare(*expect1_up, *top1_up, true); + compare(*expect2_up, *top2_up, true); EXPECT_EQUAL(13u, top1_up->getSourceId()); EXPECT_EQUAL(42u, top2_up->getSourceId()); } @@ -1181,45 +1190,25 @@ TEST("require_that_unpack_optimization_is_not_overruled_by_equiv") { } } -TEST("require that children of near are not optimized") { - auto top_up = ap((new NearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - auto expect_up = ap((new NearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - top_up = Blueprint::optimize(std::move(top_up), true); - TEST_DO(compare(*top_up, *expect_up, true)); -} - -TEST("require that children of onear are not optimized") { - auto top_up = ap((new ONearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). - addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - auto expect_up = ap((new ONearBlueprint(10))-> - addChild(addLeafs(std::make_unique<OrBlueprint>(), {20, {0, true}})). 
- addChild(addLeafs(std::make_unique<OrBlueprint>(), {{0, true}, 30}))); - top_up = Blueprint::optimize(std::move(top_up), true); - TEST_DO(compare(*top_up, *expect_up, true)); -} - TEST("require that ANDNOT without children is optimized to empty search") { Blueprint::UP top_up = std::make_unique<AndNotBlueprint>(); auto expect_up = std::make_unique<EmptyBlueprint>(); - top_up = Blueprint::optimize(std::move(top_up), true); - EXPECT_EQUAL(expect_up->asString(), top_up->asString()); + top_up = Blueprint::optimize_and_sort(std::move(top_up), true, true); + compare(*expect_up, *top_up, true); } TEST("require that highest cost tier sorts last for OR") { Blueprint::UP top = addLeafsWithCostTier(std::make_unique<OrBlueprint>(), {{50, 1}, {30, 3}, {20, 2}, {10, 1}}); Blueprint::UP expect = addLeafsWithCostTier(std::make_unique<OrBlueprint>(), {{50, 1}, {10, 1}, {20, 2}, {30, 3}}); - optimize_and_compare(std::move(top), std::move(expect), false); + // cost-based sorting would ignore cost tier + optimize_and_compare(std::move(top), std::move(expect), true, false); } TEST("require that highest cost tier sorts last for AND") { Blueprint::UP top = addLeafsWithCostTier(std::make_unique<AndBlueprint>(), {{10, 1}, {20, 3}, {30, 2}, {50, 1}}); Blueprint::UP expect = addLeafsWithCostTier(std::make_unique<AndBlueprint>(), {{10, 1}, {50, 1}, {30, 2}, {20, 3}}); - optimize_and_compare(std::move(top), std::move(expect), false); + // cost-based sorting would ignore cost tier + optimize_and_compare(std::move(top), std::move(expect), true, false); } template<typename BP> @@ -1292,6 +1281,7 @@ void verify_relative_estimate(make &&mk, double expect) { EXPECT_EQUAL(mk.making->estimate(), 0.0); Blueprint::UP bp = std::move(mk).leafs({200,300,950}); bp->setDocIdLimit(1000); + bp = Blueprint::optimize(std::move(bp)); EXPECT_EQUAL(bp->estimate(), expect); } @@ -1329,15 +1319,17 @@ TEST("relative estimate for WEAKAND") { verify_relative_estimate(make::WEAKAND(50), 0.05); } -void verify_cost(make &&mk, double expect) { - EXPECT_EQUAL(mk.making->cost(), 1.0); +void verify_cost(make &&mk, double expect, double expect_strict) { + EXPECT_EQUAL(mk.making->cost(), 0.0); + EXPECT_EQUAL(mk.making->strict_cost(), 0.0); Blueprint::UP bp = std::move(mk) - .cost(1.1).leaf(200) - .cost(1.2).leaf(300) - .cost(1.3).leaf(500); + .cost(1.1).leaf(200) // strict_cost: 0.2*1.1 + .cost(1.2).leaf(300) // strict_cost: 0.3*1.2 + .cost(1.3).leaf(950); // rel_est: 0.5, strict_cost: 1.3 bp->setDocIdLimit(1000); - bp = Blueprint::optimize(std::move(bp), true); + bp = Blueprint::optimize(std::move(bp)); EXPECT_EQUAL(bp->cost(), expect); + EXPECT_EQUAL(bp->strict_cost(), expect_strict); } double calc_cost(std::vector<std::pair<double,double>> list) { @@ -1351,36 +1343,48 @@ double calc_cost(std::vector<std::pair<double,double>> list) { } TEST("cost for OR") { - verify_cost(make::OR(), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}})); + verify_cost(make::OR(), + calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}}), + calc_cost({{0.2*1.1, 0.8},{0.3*1.2, 0.7},{1.3, 0.5}})); } TEST("cost for AND") { - verify_cost(make::AND(), calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::AND(), + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for RANK") { - verify_cost(make::RANK(), 1.1); // first + verify_cost(make::RANK(), 1.1, 0.2*1.1); // first } TEST("cost for ANDNOT") { - verify_cost(make::ANDNOT(), calc_cost({{1.1, 0.2},{1.3, 0.5},{1.2, 0.7}})); + verify_cost(make::ANDNOT(), + 
calc_cost({{1.1, 0.2},{1.3, 0.5},{1.2, 0.7}}), + calc_cost({{0.2*1.1, 0.2},{1.3, 0.5},{1.2, 0.7}})); } TEST("cost for SB") { InvalidSelector sel; - verify_cost(make::SB(sel), 1.3); // max + verify_cost(make::SB(sel), 1.3, 1.3); // max } TEST("cost for NEAR") { - verify_cost(make::NEAR(1), 3.0 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::NEAR(1), + 0.2*0.3*0.5 * 3 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + 0.2*0.3*0.5 * 3 + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for ONEAR") { - verify_cost(make::ONEAR(1), 3.0 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); + verify_cost(make::ONEAR(1), + 0.2*0.3*0.5 * 3 + calc_cost({{1.1, 0.2},{1.2, 0.3},{1.3, 0.5}}), + 0.2*0.3*0.5 * 3 + calc_cost({{0.2*1.1, 0.2},{1.2, 0.3},{1.3, 0.5}})); } TEST("cost for WEAKAND") { - verify_cost(make::WEAKAND(1000), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}})); + verify_cost(make::WEAKAND(1000), + calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}}), + calc_cost({{0.2*1.1, 0.8},{0.3*1.2, 0.7},{1.3, 0.5}})); } TEST_MAIN() { TEST_DEBUG("lhs.out", "rhs.out"); TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/queryeval/blueprint/mysearch.h b/searchlib/src/tests/queryeval/blueprint/mysearch.h index db7dd2adae6..6eb27364c2b 100644 --- a/searchlib/src/tests/queryeval/blueprint/mysearch.h +++ b/searchlib/src/tests/queryeval/blueprint/mysearch.h @@ -104,7 +104,8 @@ public: class MyLeaf : public SimpleLeafBlueprint { using TFMDA = search::fef::TermFieldMatchDataArray; - bool _got_global_filter; + bool _got_global_filter = false; + double _cost = 1.0; public: SearchIterator::UP @@ -113,18 +114,15 @@ public: return std::make_unique<MySearch>("leaf", tfmda, strict); } - MyLeaf() - : SimpleLeafBlueprint(), _got_global_filter(false) - {} - MyLeaf(FieldSpecBaseList fields) - : SimpleLeafBlueprint(std::move(fields)), _got_global_filter(false) - {} - + MyLeaf() : SimpleLeafBlueprint() {} + MyLeaf(FieldSpecBaseList fields) : SimpleLeafBlueprint(std::move(fields)) {} + void set_cost(double value) noexcept { _cost = value; } + double calculate_cost() const override { return _cost; } + MyLeaf &estimate(uint32_t hits, bool empty = false) { setEstimate(HitEstimate(hits, empty)); return *this; } - MyLeaf &cost_tier(uint32_t value) { set_cost_tier(value); return *this; @@ -153,7 +151,7 @@ private: public: explicit MyLeafSpec(uint32_t estHits, bool empty = false) - : _fields(), _estimate(estHits, empty), _cost_tier(0), _want_global_filter(false) {} + : _fields(), _estimate(estHits, empty), _cost_tier(0), _want_global_filter(false) {} MyLeafSpec &addField(uint32_t fieldId, uint32_t handle) { _fields.add(FieldSpecBase(fieldId, handle)); diff --git a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp index 1180206279d..4fc8922b9a3 100644 --- a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp +++ b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp @@ -48,7 +48,10 @@ concept ChildCollector = requires(T a, std::unique_ptr<Blueprint> bp) { // inherit Blueprint to capture the default filter factory struct DefaultBlueprint : Blueprint { double calculate_relative_estimate() const override { abort(); } - void optimize(Blueprint* &, OptimizePass, bool) override { abort(); } + double calculate_cost() const override { abort(); } + double calculate_strict_cost() const override { abort(); } + void optimize(Blueprint* &, OptimizePass) override { abort(); } + void sort(bool, bool) override { 
abort(); } const State &getState() const override { abort(); } void fetchPostings(const ExecuteInfo &) override { abort(); } void freeze() override { abort(); } diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp index a9f549a0bd9..7a7abb20cdf 100644 --- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp +++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp @@ -626,12 +626,13 @@ TEST(ParallelWeakAndTest, require_that_asString_on_blueprint_works) " estimate: HitEstimate {\n" " empty: false\n" " estHits: 2\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 2\n" " allow_termwise_eval: false\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " _weights: std::vector {\n" @@ -650,12 +651,13 @@ TEST(ParallelWeakAndTest, require_that_asString_on_blueprint_works) " estimate: HitEstimate {\n" " empty: false\n" " estHits: 2\n" - " relative_estimate: 0.5\n" " cost_tier: 1\n" " tree_size: 1\n" " allow_termwise_eval: true\n" " }\n" - " cost: 1\n" + " relative_estimate: 0\n" + " cost: 0\n" + " strict_cost: 0\n" " sourceId: 4294967295\n" " docid_limit: 0\n" " }\n" diff --git a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp index 7c535e5d3d5..c9fcb472b68 100644 --- a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp +++ b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp @@ -46,7 +46,7 @@ std::unique_ptr<SameElementBlueprint> make_blueprint(const std::vector<FakeResul } Blueprint::UP finalize(Blueprint::UP bp, bool strict) { - Blueprint::UP result = Blueprint::optimize(std::move(bp), true); + Blueprint::UP result = Blueprint::optimize_and_sort(std::move(bp), true, true); result->fetchPostings(ExecuteInfo::createForTest(strict)); result->freeze(); return result; diff --git a/searchlib/src/vespa/searchcommon/attribute/basictype.cpp b/searchlib/src/vespa/searchcommon/attribute/basictype.cpp index c63d07ca130..0312154690c 100644 --- a/searchlib/src/vespa/searchcommon/attribute/basictype.cpp +++ b/searchlib/src/vespa/searchcommon/attribute/basictype.cpp @@ -35,4 +35,13 @@ BasicType::asType(const vespalib::string &t) return NONE; } +bool +BasicType::is_integer_type() const noexcept +{ + return (_type == INT8) || + (_type == INT16) || + (_type == INT32) || + (_type == INT64); +} + } diff --git a/searchlib/src/vespa/searchcommon/attribute/basictype.h b/searchlib/src/vespa/searchcommon/attribute/basictype.h index 35200b3f62d..4bdee5ecfd9 100644 --- a/searchlib/src/vespa/searchcommon/attribute/basictype.h +++ b/searchlib/src/vespa/searchcommon/attribute/basictype.h @@ -36,6 +36,7 @@ class BasicType Type type() const noexcept { return _type; } const char * asString() const noexcept { return asString(_type); } size_t fixedSize() const noexcept { return fixedSize(_type); } + bool is_integer_type() const noexcept; static BasicType fromType(bool) noexcept { return BOOL; } static BasicType fromType(int8_t) noexcept { return INT8; } static BasicType fromType(int16_t) noexcept { return INT16; } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 5d689f5bd81..d5c67664a5e 100644 --- 
a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -389,12 +389,6 @@ private: //----------------------------------------------------------------------------- - - - - -//----------------------------------------------------------------------------- - class DirectWandBlueprint : public queryeval::ComplexLeafBlueprint { private: @@ -494,69 +488,6 @@ AttributeFieldBlueprint::getRange(vespalib::string &from, vespalib::string &to) return false; } -//----------------------------------------------------------------------------- - -class DirectAttributeBlueprint : public queryeval::SimpleLeafBlueprint -{ -private: - const IAttributeVector &_iattr; - const IDocidWithWeightPostingStore &_attr; - vespalib::datastore::EntryRef _dictionary_snapshot; - IDirectPostingStore::LookupResult _dict_entry; - -public: - DirectAttributeBlueprint(const FieldSpec &field, const IAttributeVector &iattr, - const IDocidWithWeightPostingStore &attr, - const IDirectPostingStore::LookupKey & key) - : SimpleLeafBlueprint(field), - _iattr(iattr), - _attr(attr), - _dictionary_snapshot(_attr.get_dictionary_snapshot()), - _dict_entry(_attr.lookup(key, _dictionary_snapshot)) - { - setEstimate(HitEstimate(_dict_entry.posting_size, (_dict_entry.posting_size == 0))); - } - - SearchIterator::UP createLeafSearch(const TermFieldMatchDataArray &tfmda, bool strict) const override { - assert(tfmda.size() == 1); - if (_dict_entry.posting_size == 0) { - return std::make_unique<queryeval::EmptySearch>(); - } - if (tfmda[0]->isNotNeeded()) { - auto bitvector_iterator = _attr.make_bitvector_iterator(_dict_entry.posting_idx, get_docid_limit(), *tfmda[0], strict); - if (bitvector_iterator) { - return bitvector_iterator; - } - } - if (_attr.has_btree_iterator(_dict_entry.posting_idx)) { - return std::make_unique<queryeval::DocidWithWeightSearchIterator>(*tfmda[0], _attr, _dict_entry); - } else { - return _attr.make_bitvector_iterator(_dict_entry.posting_idx, get_docid_limit(), *tfmda[0], strict); - } - } - - SearchIteratorUP createFilterSearch(bool strict, FilterConstraint constraint) const override { - (void) constraint; // We provide an iterator with exact results, so no need to take constraint into consideration. - auto wrapper = std::make_unique<FilterWrapper>(getState().numFields()); - wrapper->wrap(createLeafSearch(wrapper->tfmda(), strict)); - return wrapper; - } - - void visitMembers(vespalib::ObjectVisitor &visitor) const override { - LeafBlueprint::visitMembers(visitor); - visit_attribute(visitor, _iattr); - } - std::unique_ptr<queryeval::MatchingElementsSearch> create_matching_elements_search(const MatchingElementsFields &fields) const override { - if (fields.has_field(_iattr.getName())) { - return queryeval::MatchingElementsSearch::create(_iattr, _dictionary_snapshot, vespalib::ConstArrayRef<IDirectPostingStore::LookupResult>(&_dict_entry, 1)); - } else { - return {}; - } - } -}; - -//----------------------------------------------------------------------------- - bool check_valid_diversity_attr(const IAttributeVector *attr) { if ((attr == nullptr) || attr->hasMultiValue()) { return false; @@ -579,8 +510,7 @@ private: const IDocidWithWeightPostingStore *_dwwps; vespalib::string _scratchPad; - bool use_docid_with_weight_posting_store() const { - // TODO: Relax requirement on always having weight iterator for query operators where that makes sense. 
+ bool has_always_btree_iterators_with_docid_and_weight() const { return (_dwwps != nullptr) && (_dwwps->has_always_btree_iterator()); } @@ -598,15 +528,6 @@ public: ~CreateBlueprintVisitor() override; template <class TermNode> - void visitSimpleTerm(TermNode &n) { - if (use_docid_with_weight_posting_store() && !_field.isFilter() && n.isRanked() && !Term::isPossibleRangeTerm(n.getTerm())) { - NodeAsKey key(n, _scratchPad); - setResult(std::make_unique<DirectAttributeBlueprint>(_field, _attr, *_dwwps, key)); - } else { - visitTerm(n); - } - } - template <class TermNode> void visitTerm(TermNode &n) { SearchContextParams scParams = createContextParams(_field.isFilter()); scParams.fuzzy_matching_algorithm(getRequestContext().get_attribute_blueprint_params().fuzzy_matching_algorithm); @@ -628,7 +549,7 @@ public: } } - void visit(NumberTerm & n) override { visitSimpleTerm(n); } + void visit(NumberTerm & n) override { visitTerm(n); } void visit(LocationTerm &n) override { visitLocation(n); } void visit(PrefixTerm & n) override { visitTerm(n); } @@ -652,7 +573,7 @@ public: } } - void visit(StringTerm & n) override { visitSimpleTerm(n); } + void visit(StringTerm & n) override { visitTerm(n); } void visit(SubstringTerm & n) override { query::SimpleRegExpTerm re(vespalib::RegexpUtil::make_from_substring(n.getTerm()), n.getView(), n.getId(), n.getWeight()); @@ -680,32 +601,41 @@ public: return std::make_unique<QueryTermUCS4>(term, QueryTermSimple::Type::WORD); } - void visit(query::WeightedSetTerm &n) override { - bool isSingleValue = !_attr.hasMultiValue(); - bool isString = (_attr.isStringType() && _attr.hasEnum()); - bool isInteger = _attr.isIntegerType(); - if (isSingleValue && (isString || isInteger)) { - auto ws = std::make_unique<AttributeWeightedSetBlueprint>(_field, _attr); - SearchContextParams scParams = createContextParams(); - for (size_t i = 0; i < n.getNumTerms(); ++i) { - auto term = n.getAsString(i); - ws->addToken(_attr.createSearchContext(extractTerm(term.first, isInteger), scParams), term.second.percent()); - } - setResult(std::move(ws)); + template <typename TermType, typename SearchType> + void visit_wset_or_in_term(TermType& n) { + if (_dps != nullptr) { + auto* bp = new attribute::DirectMultiTermBlueprint<IDocidPostingStore, SearchType> + (_field, _attr, *_dps, n.getNumTerms()); + createDirectMultiTerm(bp, n); + } else if (_dwwps != nullptr) { + auto* bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, SearchType> + (_field, _attr, *_dwwps, n.getNumTerms()); + createDirectMultiTerm(bp, n); } else { - if (use_docid_with_weight_posting_store()) { - auto *bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, queryeval::WeightedSetTermSearch> - (_field, _attr, *_dwwps, n.getNumTerms()); - createDirectMultiTerm(bp, n); + bool isSingleValue = !_attr.hasMultiValue(); + bool isString = (_attr.isStringType() && _attr.hasEnum()); + bool isInteger = _attr.isIntegerType(); + if (isSingleValue && (isString || isInteger)) { + auto ws = std::make_unique<AttributeWeightedSetBlueprint>(_field, _attr); + SearchContextParams scParams = createContextParams(); + for (size_t i = 0; i < n.getNumTerms(); ++i) { + auto term = n.getAsString(i); + ws->addToken(_attr.createSearchContext(extractTerm(term.first, isInteger), scParams), term.second.percent()); + } + setResult(std::move(ws)); } else { - auto *bp = new WeightedSetTermBlueprint(_field); + auto* bp = new WeightedSetTermBlueprint(_field); createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); } 
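// Editorial note: the dispatch order above is deliberate; a plain docid
// posting store or a docid-with-weight posting store yields a
// DirectMultiTermBlueprint, single-value string/integer attributes get an
// AttributeWeightedSetBlueprint, and the generic shallow
// WeightedSetTermBlueprint is the final fallback.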
} } + void visit(query::WeightedSetTerm &n) override { + visit_wset_or_in_term<query::WeightedSetTerm, queryeval::WeightedSetTermSearch>(n); + } + void visit(query::DotProduct &n) override { - if (use_docid_with_weight_posting_store()) { + if (has_always_btree_iterators_with_docid_and_weight()) { auto *bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, queryeval::DotProductSearch> (_field, _attr, *_dwwps, n.getNumTerms()); createDirectMultiTerm(bp, n); @@ -716,7 +646,7 @@ public: } void visit(query::WandTerm &n) override { - if (use_docid_with_weight_posting_store()) { + if (has_always_btree_iterators_with_docid_and_weight()) { auto *bp = new DirectWandBlueprint(_field, *_dwwps, n.getTargetNumHits(), n.getScoreThreshold(), n.getThresholdBoostFactor(), n.getNumTerms()); @@ -731,18 +661,7 @@ public: } void visit(query::InTerm &n) override { - if (_dps != nullptr) { - auto* bp = new attribute::DirectMultiTermBlueprint<IDocidPostingStore, attribute::InTermSearch> - (_field, _attr, *_dps, n.getNumTerms()); - createDirectMultiTerm(bp, n); - } else if (_dwwps != nullptr) { - auto* bp = new attribute::DirectMultiTermBlueprint<IDocidWithWeightPostingStore, attribute::InTermSearch> - (_field, _attr, *_dwwps, n.getNumTerms()); - createDirectMultiTerm(bp, n); - } else { - auto* bp = new WeightedSetTermBlueprint(_field); - createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); - } + visit_wset_or_in_term<query::InTerm, attribute::InTermSearch>(n); } void fail_nearest_neighbor_term(query::NearestNeighborTerm&n, const vespalib::string& error_msg) { diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 01148c11c9c..d0353ab8947 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -6,12 +6,12 @@ #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/matchdatalayout.h> #include <vespa/searchlib/query/query_term_ucs4.h> +#include <vespa/searchlib/queryeval/filter_wrapper.h> +#include <vespa/searchlib/queryeval/orsearch.h> #include <vespa/searchlib/queryeval/weighted_set_term_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/searchlib/queryeval/filter_wrapper.h> -#include <vespa/searchlib/queryeval/orsearch.h> +#include <vespa/vespalib/util/stringfmt.h> namespace search { @@ -38,6 +38,7 @@ class StringEnumWrapper : public AttrWrapper { public: using TokenT = uint32_t; + static constexpr bool unpack_weights = true; explicit StringEnumWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} auto mapToken(const ISearchContext &context) const { @@ -52,6 +53,7 @@ class IntegerWrapper : public AttrWrapper { public: using TokenT = uint64_t; + static constexpr bool unpack_weights = true; explicit IntegerWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h index 42651854599..485427391ad 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h @@ -35,6 +35,8 
@@ private: using IteratorType = typename PostingStoreType::IteratorType; using IteratorWeights = std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>>; + bool use_hash_filter(bool strict) const; + IteratorWeights create_iterators(std::vector<IteratorType>& btree_iterators, std::vector<std::unique_ptr<queryeval::SearchIterator>>& bitvectors, bool use_bitvector_when_available, diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp index 6fc8ac63026..0a3b24142a5 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp @@ -8,6 +8,7 @@ #include <vespa/searchlib/queryeval/emptysearch.h> #include <vespa/searchlib/queryeval/filter_wrapper.h> #include <vespa/searchlib/queryeval/orsearch.h> +#include <cmath> #include <memory> #include <type_traits> @@ -38,6 +39,43 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::DirectMultiTermBlueprint template <typename PostingStoreType, typename SearchType> DirectMultiTermBlueprint<PostingStoreType, SearchType>::~DirectMultiTermBlueprint() = default; + +template <typename PostingStoreType, typename SearchType> +bool +DirectMultiTermBlueprint<PostingStoreType, SearchType>::use_hash_filter(bool strict) const +{ + if (strict || _iattr.hasMultiValue()) { + return false; + } + // The following very simplified formula was created after analysing performance of the IN operator + // on a 10M document corpus using a machine with an Intel Xeon 2.5 GHz CPU with 48 cores and 256 GB of memory: + // https://github.com/vespa-engine/system-test/tree/master/tests/performance/in_operator + // + // The following 25 test cases were used to calculate the cost of using btree iterators (strict): + // op_hits_ratios = [5, 10, 50, 100, 200] * tokens_in_op = [1, 5, 10, 100, 1000] + // For each case we calculate the latency difference against the case with tokens_in_op=1 and the same op_hits_ratio. + // This indicates the extra time used to produce the same number of hits when having multiple tokens in the operator. + // The latency diff is divided by the number of hits produced and converted to nanoseconds: + // 10M * (op_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we can approximate the cost per document (in nanoseconds) as: + // 8.0 (ns) * log2(tokens_in_op). + // NOTE: This is very simplified. Ideally we should also take into consideration the hit estimate of this blueprint, + // as the cost per document will be lower when producing few hits. + // + // In addition, the following 28 test cases were used to calculate the cost of using the hash filter (non-strict): + // filter_hits_ratios = [1, 5, 10, 50, 100, 150, 200] x op_hits_ratios = [200] x tokens_in_op = [5, 10, 100, 1000] + // The code was altered to always use the hash filter for non-strict iterators. + // For each case we calculate the latency difference against a case from above with tokens_in_op=1 that produces a similar number of hits. + // This indicates the extra time used to produce the same number of hits when using the hash filter. + // The latency diff is divided by the number of hits the test filter produces and converted to nanoseconds: + // 10M * (filter_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we calculate the average cost per document (in nanoseconds) as 26.0 ns.
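// Editorial worked example (using only the constants above): the hash
// filter is chosen when 26.0 < 8.0 * log2(tokens_in_op), i.e. from roughly
// 10 tokens, since 8.0 * log2(10) ~= 26.6 ns/doc; with 1000 tokens the
// btree estimate grows to 8.0 * log2(1000) ~= 79.7 ns/doc versus the flat
// 26.0 ns/doc of the hash filter.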
+ + float hash_filter_cost_per_doc_ns = 26.0; + float btree_iterator_cost_per_doc_ns = 8.0 * std::log2(_terms.size()); + return hash_filter_cost_per_doc_ns < btree_iterator_cost_per_doc_ns; +} + template <typename PostingStoreType, typename SearchType> typename DirectMultiTermBlueprint<PostingStoreType, SearchType>::IteratorWeights DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_iterators(std::vector<IteratorType>& btree_iterators, @@ -96,14 +134,21 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_search_helper(con if (_terms.empty()) { return std::make_unique<queryeval::EmptySearch>(); } + auto& tfmd = *tfmda[0]; + bool field_is_filter = getState().fields()[0].isFilter(); + if constexpr (SearchType::supports_hash_filter) { + if (use_hash_filter(strict)) { + return SearchType::create_hash_filter(tfmd, (filter_search || field_is_filter), + _weights, _terms, + _iattr, _attr, _dictionary_snapshot); + } + } std::vector<IteratorType> btree_iterators; std::vector<queryeval::SearchIterator::UP> bitvectors; const size_t num_children = _terms.size(); btree_iterators.reserve(num_children); - auto& tfmd = *tfmda[0]; bool use_bit_vector_when_available = filter_search || !_attr.has_always_btree_iterator(); auto weights = create_iterators(btree_iterators, bitvectors, use_bit_vector_when_available, tfmd, strict); - bool field_is_filter = getState().fields()[0].isFilter(); if constexpr (!SearchType::require_btree_iterators) { auto multi_term = !btree_iterators.empty() ? SearchType::create(tfmd, (filter_search || field_is_filter), std::move(weights), std::move(btree_iterators)) diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h index 9c3ea258fdc..43500366f21 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h @@ -23,6 +23,7 @@ class MultiTermHashFilter final : public queryeval::SearchIterator public: using Key = typename WrapperType::TokenT; using TokenMap = vespalib::hash_map<Key, int32_t, vespalib::hash<Key>, std::equal_to<Key>, vespalib::hashtable_base::and_modulator>; + static constexpr bool unpack_weights = WrapperType::unpack_weights; private: fef::TermFieldMatchData& _tfmd; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp index 96d5b3ac1f3..e67bda147d7 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp @@ -42,10 +42,14 @@ template <typename WrapperType> void MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId) { - _tfmd.reset(docId); - fef::TermFieldMatchDataPosition pos; - pos.setElementWeight(_weight); - _tfmd.appendPosition(pos); + if constexpr (unpack_weights) { + _tfmd.reset(docId); + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(_weight); + _tfmd.appendPosition(pos); + } else { + _tfmd.resetOnlyDocId(docId); + } } } diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp index ea1058d88fb..9328857c919 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp @@ -89,7 +89,7 @@ template <typename B, typename M> const IDocidWithWeightPostingStore* 
MultiValueNumericPostingAttribute<B, M>::as_docid_with_weight_posting_store() const { - if (this->hasWeightedSetType() && (this->getBasicType() == AttributeVector::BasicType::INT64)) { + if (this->getConfig().basicType().is_integer_type()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp index b6e9b69a81d..371b8d920f9 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp @@ -108,8 +108,7 @@ template <typename B, typename T> const IDocidWithWeightPostingStore* MultiValueStringPostingAttributeT<B, T>::as_docid_with_weight_posting_store() const { - // TODO: Add support for handling bit vectors too, and lift restriction on isFilter. - if (this->hasWeightedSetType() && this->isStringType()) { + if (this->isStringType()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp index c57742ca4b6..c74e25fbf5a 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp @@ -150,22 +150,11 @@ SingleValueNumericPostingAttribute<B>::getSearch(QueryTermSimple::UP qTerm, return std::make_unique<SC>(std::move(base_sc), params, *this); } -namespace { - -bool is_integer_type(attribute::BasicType type) { - return (type == attribute::BasicType::INT8) || - (type == attribute::BasicType::INT16) || - (type == attribute::BasicType::INT32) || - (type == attribute::BasicType::INT64); -} - -} - template <typename B> const IDocidPostingStore* SingleValueNumericPostingAttribute<B>::as_docid_posting_store() const { - if (is_integer_type(this->getBasicType())) { + if (this->getConfig().basicType().is_integer_type()) { return &_posting_store_adapter; } return nullptr; diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 05a75f4662e..63d52cbdf9f 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -2,13 +2,19 @@ vespa_add_library(searchlib_query_streaming OBJECT SOURCES dot_product_term.cpp + fuzzy_term.cpp + hit_iterator_pack.cpp in_term.cpp multi_term.cpp + near_query_node.cpp nearest_neighbor_query_node.cpp + onear_query_node.cpp + phrase_query_node.cpp query.cpp querynode.cpp querynoderesultbase.cpp queryterm.cpp + same_element_query_node.cpp wand_term.cpp weighted_set_term.cpp regexp_term.cpp diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 1871bda564d..09840d9a126 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -24,7 +24,7 @@ DotProductTerm::build_scores(Scores& scores) const for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); + scores[hit.field_id()] += ((int64_t)term->weight().percent()) * hit.element_weight(); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp 
b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp new file mode 100644 index 00000000000..f33fe44369a --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp @@ -0,0 +1,43 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "fuzzy_term.h" + +namespace search::streaming { + +namespace { + +constexpr bool normalizing_implies_cased(Normalizing norm) noexcept { + return (norm == Normalizing::NONE); +} + +} + +FuzzyTerm::FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size) + : QueryTerm(std::move(result_base), term, index, type, normalizing), + _dfa_matcher(), + _fallback_matcher() +{ + setFuzzyMaxEditDistance(max_edits); + setFuzzyPrefixLength(prefix_size); + + std::string_view term_view(term.data(), term.size()); + const bool cased = normalizing_implies_cased(normalizing); + if (attribute::DfaFuzzyMatcher::supports_max_edits(max_edits)) { + _dfa_matcher = std::make_unique<attribute::DfaFuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } else { + _fallback_matcher = std::make_unique<vespalib::FuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } +} + +FuzzyTerm::~FuzzyTerm() = default; + +bool FuzzyTerm::is_match(std::string_view term) const { + if (_dfa_matcher) { + return _dfa_matcher->is_match(term); + } else { + return _fallback_matcher->isMatch(term); + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h new file mode 100644 index 00000000000..c6c88b18969 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "queryterm.h" +#include <vespa/searchlib/attribute/dfa_fuzzy_matcher.h> +#include <vespa/vespalib/fuzzy/fuzzy_matcher.h> +#include <memory> +#include <string_view> + +namespace search::streaming { + +/** + * Query term that matches candidate field terms that are within a query-specified + * maximum number of edits (add, delete or substitute a character), with case + * sensitivity controlled by the provided Normalizing mode. + * + * Optionally, terms may be prefix-locked, which requires field terms to have a + * particular prefix; edits are then only counted against the remaining term suffix.
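 *
 * Editorial example (hypothetical terms): with max_edits=2 and prefix_size=3,
 * the query term "network" matches the field term "netware" (two
 * substitutions) but not "notwork", which is a single edit away yet fails
 * the locked "net" prefix.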
+ */ +class FuzzyTerm : public QueryTerm { + std::unique_ptr<attribute::DfaFuzzyMatcher> _dfa_matcher; + std::unique_ptr<vespalib::FuzzyMatcher> _fallback_matcher; +public: + FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size); + ~FuzzyTerm() override; + + [[nodiscard]] FuzzyTerm* as_fuzzy_term() noexcept override { return this; } + + [[nodiscard]] bool is_match(std::string_view term) const; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit.h b/searchlib/src/vespa/searchlib/query/streaming/hit.h index a798d293491..cd72555ea66 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/hit.h +++ b/searchlib/src/vespa/searchlib/query/streaming/hit.h @@ -8,23 +8,21 @@ namespace search::streaming { class Hit { + uint32_t _field_id; + uint32_t _element_id; + int32_t _element_weight; + uint32_t _position; public: - Hit(uint32_t pos_, uint32_t context_, uint32_t elemId_, int32_t weight_) noexcept - : _position(pos_ | (context_<<24)), - _elemId(elemId_), - _weight(weight_) + Hit(uint32_t field_id_, uint32_t element_id_, int32_t element_weight_, uint32_t position_) noexcept + : _field_id(field_id_), + _element_id(element_id_), + _element_weight(element_weight_), + _position(position_) { } - int32_t weight() const { return _weight; } - uint32_t pos() const { return _position; } - uint32_t wordpos() const { return _position & 0xffffff; } - uint32_t context() const { return _position >> 24; } - uint32_t elemId() const { return _elemId; } - bool operator < (const Hit & b) const { return cmp(b) < 0; } -private: - int cmp(const Hit & b) const { return _position - b._position; } - uint32_t _position; - uint32_t _elemId; - int32_t _weight; + uint32_t field_id() const noexcept { return _field_id; } + uint32_t element_id() const { return _element_id; } + int32_t element_weight() const { return _element_weight; } + uint32_t position() const { return _position; } }; using HitList = std::vector<Hit>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h new file mode 100644 index 00000000000..21eba679abd --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator.h @@ -0,0 +1,73 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "hit.h" + +namespace search::streaming { + +/* + * Iterator used over hit list for a term to support near, onear, phrase and + * same element query nodes. + */ +class HitIterator +{ + HitList::const_iterator _cur; + HitList::const_iterator _end; +public: + using FieldElement = std::pair<uint32_t, uint32_t>; + HitIterator(const HitList& hl) noexcept + : _cur(hl.begin()), + _end(hl.end()) + { } + bool valid() const noexcept { return _cur != _end; } + const Hit* operator->() const noexcept { return _cur.operator->(); } + const Hit& operator*() const noexcept { return _cur.operator*(); } + FieldElement get_field_element() const noexcept { return std::make_pair(_cur->field_id(), _cur->element_id()); } + bool seek_to_field_element(const FieldElement& field_element) noexcept { + while (valid()) { + if (!(get_field_element() < field_element)) { + return true; + } + ++_cur; + } + return false; + } + /* + * Step iterator forwards within the scope of the same field + * element. 
Return true if iterator is valid and with the same + * field element, otherwise return false but update field_element + * if iterator is valid to prepare for hit iterator pack seeking + * to next matching field element. + */ + bool step_in_field_element(FieldElement& field_element) noexcept { + ++_cur; + if (!valid()) { + return false; + } + auto ife = get_field_element(); + if (field_element < ife) { + field_element = ife; + return false; + } + return true; + } + /* + * Seek to position within the scope of the same field element. + * Return true if iterator is valid and with the same field + * element, otherwise return false but update field element if + * iterator is valid to prepare for hit iterator pack seeking to + * next matching field element. + */ + bool seek_in_field_element(uint32_t position, FieldElement& field_element) { + while (_cur->position() < position) { + if (!step_in_field_element(field_element)) { + return false; + } + } + return true; + } + HitIterator& operator++() { ++_cur; return *this; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp new file mode 100644 index 00000000000..fabd992c379 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.cpp @@ -0,0 +1,57 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "hit_iterator_pack.h" + +namespace search::streaming { + + +HitIteratorPack::HitIteratorPack(const QueryNodeList& children) + : _iterators(), + _field_element(std::make_pair(0, 0)) +{ + auto num_children = children.size(); + _iterators.reserve(num_children); + for (auto& child : children) { + auto& curr = dynamic_cast<const QueryTerm&>(*child); + _iterators.emplace_back(curr.getHitList()); + } +} + +HitIteratorPack::~HitIteratorPack() = default; + +bool +HitIteratorPack::all_valid() const noexcept +{ + if (_iterators.empty()) { + return false; + } + for (auto& it : _iterators) { + if (!it.valid()) { + return false; + } + } + return true; +} + +bool +HitIteratorPack::seek_to_matching_field_element() noexcept +{ + bool retry = true; + while (retry) { + retry = false; + for (auto& it : _iterators) { + if (!it.seek_to_field_element(_field_element)) { + return false; + } + auto ife = it.get_field_element(); + if (_field_element < ife) { + _field_element = ife; + retry = true; + break; + } + } + } + return true; +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h new file mode 100644 index 00000000000..200d579930a --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/hit_iterator_pack.h @@ -0,0 +1,32 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "hit_iterator.h" +#include "queryterm.h" + +namespace search::streaming { + +/* + * Iterator pack used over hit list for a term to support near, onear, + * phrase and same element query nodes. 
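 *
 * Editorial note: a consumer loops on seek_to_matching_field_element(); on
 * success all iterators point into the same (field_id, element_id) pair, and
 * a failing step_in_field_element()/seek_in_field_element() publishes the
 * next candidate through the shared field element so the following seek
 * resumes from there.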
+ */ +class HitIteratorPack +{ + using iterator = typename std::vector<HitIterator>::iterator; + using FieldElement = HitIterator::FieldElement; + std::vector<HitIterator> _iterators; + FieldElement _field_element; +public: + explicit HitIteratorPack(const QueryNodeList& children); + ~HitIteratorPack(); + FieldElement& get_field_element_ref() noexcept { return _field_element; } + HitIterator& front() noexcept { return _iterators.front(); } + HitIterator& back() noexcept { return _iterators.back(); } + iterator begin() noexcept { return _iterators.begin(); } + iterator end() noexcept { return _iterators.end(); } + bool all_valid() const noexcept; + bool seek_to_matching_field_element() noexcept; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp index 3e75f4a5114..c164db69ba1 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp @@ -29,9 +29,9 @@ InTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_ for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - if (!prev_field_id.has_value() || prev_field_id.value() != hit.context()) { - prev_field_id = hit.context(); - matching_field_ids.insert(hit.context()); + if (!prev_field_id.has_value() || prev_field_id.value() != hit.field_id()) { + prev_field_id = hit.field_id(); + matching_field_ids.insert(hit.field_id()); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp index dd34b9b7e73..f5a09892551 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp @@ -51,11 +51,4 @@ MultiTerm::evaluate() const return false; } -const HitList& -MultiTerm::evaluateHits(HitList& hl) const -{ - hl.clear(); - return hl; -} - } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h index 3bb69e29693..6f795c31356 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h @@ -32,7 +32,6 @@ public: MultiTerm* as_multi_term() noexcept override { return this; } void reset() override; bool evaluate() const override; - const HitList& evaluateHits(HitList& hl) const override; virtual void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) = 0; const std::vector<std::unique_ptr<QueryTerm>>& get_terms() const noexcept { return _terms; } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp new file mode 100644 index 00000000000..126ed1ff69e --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.cpp @@ -0,0 +1,21 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "near_query_node.h" +#include <vespa/vespalib/objects/visit.hpp> + +namespace search::streaming { + +bool +NearQueryNode::evaluate() const +{ + return AndQueryNode::evaluate(); +} + +void +NearQueryNode::visitMembers(vespalib::ObjectVisitor &visitor) const +{ + AndQueryNode::visitMembers(visitor); + visit(visitor, "distance", static_cast<uint64_t>(_distance)); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h new file mode 100644 index 00000000000..63b7e748296 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/near_query_node.h @@ -0,0 +1,26 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary Near operator. All terms must be within the given distance. +*/ +class NearQueryNode : public AndQueryNode +{ +public: + NearQueryNode() noexcept : AndQueryNode("NEAR"), _distance(0) { } + explicit NearQueryNode(const char * opName) noexcept : AndQueryNode(opName), _distance(0) { } + bool evaluate() const override; + void distance(size_t dist) { _distance = dist; } + size_t distance() const { return _distance; } + void visitMembers(vespalib::ObjectVisitor &visitor) const override; + bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } +private: + size_t _distance; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp new file mode 100644 index 00000000000..a06ff0f90e1 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.cpp @@ -0,0 +1,14 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onear_query_node.h" + +namespace search::streaming { + +bool +ONearQueryNode::evaluate() const +{ + bool ok(NearQueryNode::evaluate()); + return ok; +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h new file mode 100644 index 00000000000..649496b62d9 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/onear_query_node.h @@ -0,0 +1,20 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "near_query_node.h" + +namespace search::streaming { + +/** + N-ary Ordered near operator. The terms must be in order and the distance between + the first and last must not exceed the given distance. +*/ +class ONearQueryNode : public NearQueryNode +{ +public: + ONearQueryNode() noexcept : NearQueryNode("ONEAR") { } + bool evaluate() const override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp new file mode 100644 index 00000000000..2d2778417fa --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.cpp @@ -0,0 +1,82 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "phrase_query_node.h" +#include "hit_iterator_pack.h" +#include <cassert> + +namespace search::streaming { + +bool +PhraseQueryNode::evaluate() const +{ + HitList hl; + return ! 
evaluateHits(hl).empty(); +} + +void PhraseQueryNode::getPhrases(QueryNodeRefList & tl) { tl.push_back(this); } +void PhraseQueryNode::getPhrases(ConstQueryNodeRefList & tl) const { tl.push_back(this); } + +void +PhraseQueryNode::addChild(QueryNode::UP child) { + assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); + AndQueryNode::addChild(std::move(child)); +} + +const HitList & +PhraseQueryNode::evaluateHits(HitList & hl) const +{ + hl.clear(); + _fieldInfo.clear(); + HitIteratorPack itr_pack(getChildren()); + if (!itr_pack.all_valid()) { + return hl; + } + auto& last_child = dynamic_cast<const QueryTerm&>(*(*this)[size() - 1]); + while (itr_pack.seek_to_matching_field_element()) { + uint32_t first_position = itr_pack.front()->position(); + bool retry_element = true; + while (retry_element) { + uint32_t position_offset = 0; + bool match = true; + for (auto& it : itr_pack) { + if (!it.seek_in_field_element(first_position + position_offset, itr_pack.get_field_element_ref())) { + retry_element = false; + match = false; + break; + } + if (it->position() > first_position + position_offset) { + first_position = it->position() - position_offset; + match = false; + break; + } + ++position_offset; + } + if (match) { + auto h = *itr_pack.back(); + hl.push_back(h); + auto& fi = last_child.getFieldInfo(h.field_id()); + updateFieldInfo(h.field_id(), hl.size() - 1, fi.getFieldLength()); + if (!itr_pack.front().step_in_field_element(itr_pack.get_field_element_ref())) { + retry_element = false; + } + } + } + } + return hl; +} + +void +PhraseQueryNode::updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const +{ + if (fid >= _fieldInfo.size()) { + _fieldInfo.resize(fid + 1); + // only set hit offset and field length the first time + QueryTerm::FieldInfo & fi = _fieldInfo[fid]; + fi.setHitOffset(offset); + fi.setFieldLength(fieldLength); + } + QueryTerm::FieldInfo & fi = _fieldInfo[fid]; + fi.setHitCount(fi.getHitCount() + 1); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h new file mode 100644 index 00000000000..f28684ffeb4 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/phrase_query_node.h @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary phrase operator. All terms must be satisfied and have the correct order + with distance to next term equal to 1. 
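
 Editorial example: for the phrase "new york", a field element containing
 "new" at position 7 matches only if "york" occurs at position 8 in the same
 element; position 9, or a hit in another element, is a miss.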
+*/ +class PhraseQueryNode : public AndQueryNode +{ +public: + PhraseQueryNode() noexcept : AndQueryNode("PHRASE"), _fieldInfo(32) { } + bool evaluate() const override; + const HitList & evaluateHits(HitList & hl) const override; + void getPhrases(QueryNodeRefList & tl) override; + void getPhrases(ConstQueryNodeRefList & tl) const override; + const QueryTerm::FieldInfo & getFieldInfo(size_t fid) const { return _fieldInfo[fid]; } + size_t getFieldInfoSize() const { return _fieldInfo.size(); } + bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } + void addChild(QueryNode::UP child) override; +private: + mutable std::vector<QueryTerm::FieldInfo> _fieldInfo; + void updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const; +#if WE_EVER_NEED_TO_CACHE_THIS_WE_MIGHT_WANT_SOME_CODE_HERE + HitList _cachedHitList; + bool _evaluated; +#endif +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp index ca742aabe26..5b0076a30c5 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp @@ -1,5 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "query.h" +#include "near_query_node.h" +#include "onear_query_node.h" +#include "phrase_query_node.h" +#include "same_element_query_node.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/vespalib/objects/visit.hpp> #include <cassert> @@ -31,7 +35,7 @@ const HitList & QueryConnector::evaluateHits(HitList & hl) const { if (evaluate()) { - hl.emplace_back(1, 0, 0, 1); + hl.emplace_back(0, 0, 1, 1); } return hl; } @@ -113,6 +117,7 @@ QueryConnector::create(ParseItem::ItemType type) case search::ParseItem::ITEM_SAME_ELEMENT: return std::make_unique<SameElementQueryNode>(); case search::ParseItem::ITEM_NEAR: return std::make_unique<NearQueryNode>(); case search::ParseItem::ITEM_ONEAR: return std::make_unique<ONearQueryNode>(); + case search::ParseItem::ITEM_RANK: return std::make_unique<RankWithQueryNode>(); default: return nullptr; } } @@ -158,174 +163,23 @@ OrQueryNode::evaluate() const { return false; } - -bool -EquivQueryNode::evaluate() const -{ - return OrQueryNode::evaluate(); -} - bool -SameElementQueryNode::evaluate() const { - HitList hl; - return ! 
evaluateHits(hl).empty(); -} - -void -SameElementQueryNode::addChild(QueryNode::UP child) { - assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); - AndQueryNode::addChild(std::move(child)); -} - -const HitList & -SameElementQueryNode::evaluateHits(HitList & hl) const -{ - hl.clear(); - if ( !AndQueryNode::evaluate()) return hl; - - HitList tmpHL; - const auto & children = getChildren(); - unsigned int numFields = children.size(); - unsigned int currMatchCount = 0; - std::vector<unsigned int> indexVector(numFields, 0); - auto curr = static_cast<const QueryTerm *> (children[currMatchCount].get()); - bool exhausted( curr->evaluateHits(tmpHL).empty()); - for (; !exhausted; ) { - auto next = static_cast<const QueryTerm *>(children[currMatchCount+1].get()); - unsigned int & currIndex = indexVector[currMatchCount]; - unsigned int & nextIndex = indexVector[currMatchCount+1]; - - const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - uint32_t currElemId = currHit.elemId(); - - const HitList & nextHL = next->evaluateHits(tmpHL); - - size_t nextIndexMax = nextHL.size(); - while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() < currElemId)) { - nextIndex++; - } - if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() == currElemId)) { - currMatchCount++; - if ((currMatchCount+1) == numFields) { - Hit h = nextHL[indexVector[currMatchCount]]; - hl.emplace_back(0, h.context(), h.elemId(), h.weight()); - currMatchCount = 0; - indexVector[0]++; - } - } else { - currMatchCount = 0; - indexVector[currMatchCount]++; - } - curr = static_cast<const QueryTerm *>(children[currMatchCount].get()); - exhausted = (nextIndex >= nextIndexMax) || (indexVector[currMatchCount] >= curr->evaluateHits(tmpHL).size()); - } - return hl; -} - -bool -PhraseQueryNode::evaluate() const -{ - HitList hl; - return ! evaluateHits(hl).empty(); -} - -void PhraseQueryNode::getPhrases(QueryNodeRefList & tl) { tl.push_back(this); } -void PhraseQueryNode::getPhrases(ConstQueryNodeRefList & tl) const { tl.push_back(this); } - -void -PhraseQueryNode::addChild(QueryNode::UP child) { - assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); - AndQueryNode::addChild(std::move(child)); -} - -const HitList & -PhraseQueryNode::evaluateHits(HitList & hl) const -{ - hl.clear(); - _fieldInfo.clear(); - if ( ! 
AndQueryNode::evaluate()) return hl; - - HitList tmpHL; - const auto & children = getChildren(); - unsigned int fullPhraseLen = children.size(); - unsigned int currPhraseLen = 0; - std::vector<unsigned int> indexVector(fullPhraseLen, 0); - auto curr = static_cast<const QueryTerm *> (children[currPhraseLen].get()); - bool exhausted( curr->evaluateHits(tmpHL).empty()); - for (; !exhausted; ) { - auto next = static_cast<const QueryTerm *>(children[currPhraseLen+1].get()); - unsigned int & currIndex = indexVector[currPhraseLen]; - unsigned int & nextIndex = indexVector[currPhraseLen+1]; - - const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - size_t firstPosition = currHit.pos(); - uint32_t currElemId = currHit.elemId(); - uint32_t currContext = currHit.context(); - - const HitList & nextHL = next->evaluateHits(tmpHL); - - int diff(0); - size_t nextIndexMax = nextHL.size(); - while ((nextIndex < nextIndexMax) && - ((nextHL[nextIndex].context() < currContext) || - ((nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() <= currElemId))) && - ((diff = nextHL[nextIndex].pos()-firstPosition) < 1)) - { - nextIndex++; - } - if ((diff == 1) && (nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() == currElemId)) { - currPhraseLen++; - if ((currPhraseLen+1) == fullPhraseLen) { - Hit h = nextHL[indexVector[currPhraseLen]]; - hl.push_back(h); - const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.context()); - updateFieldInfo(h.context(), hl.size() - 1, fi.getFieldLength()); - currPhraseLen = 0; - indexVector[0]++; - } - } else { - currPhraseLen = 0; - indexVector[currPhraseLen]++; +RankWithQueryNode::evaluate() const { + bool first = true; + bool firstOk = false; + for (const auto & qn : getChildren()) { + if (qn->evaluate()) { + if (first) firstOk = true; } - curr = static_cast<const QueryTerm *>(children[currPhraseLen].get()); - exhausted = (nextIndex >= nextIndexMax) || (indexVector[currPhraseLen] >= curr->evaluateHits(tmpHL).size()); - } - return hl; -} - -void -PhraseQueryNode::updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const -{ - if (fid >= _fieldInfo.size()) { - _fieldInfo.resize(fid + 1); - // only set hit offset and field length the first time - QueryTerm::FieldInfo & fi = _fieldInfo[fid]; - fi.setHitOffset(offset); - fi.setFieldLength(fieldLength); + first = false; } - QueryTerm::FieldInfo & fi = _fieldInfo[fid]; - fi.setHitCount(fi.getHitCount() + 1); -} - -bool -NearQueryNode::evaluate() const -{ - return AndQueryNode::evaluate(); -} - -void -NearQueryNode::visitMembers(vespalib::ObjectVisitor &visitor) const -{ - AndQueryNode::visitMembers(visitor); - visit(visitor, "distance", static_cast<uint64_t>(_distance)); + return firstOk; } - bool -ONearQueryNode::evaluate() const +EquivQueryNode::evaluate() const { - bool ok(NearQueryNode::evaluate()); - return ok; + return OrQueryNode::evaluate(); } Query::Query() = default; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 84c693b86d0..4ab33a01d86 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -95,79 +95,28 @@ public: }; /** - N-ary "EQUIV" operator that merges terms from nodes below. 
+ N-ary RankWith operator */ -class EquivQueryNode : public OrQueryNode +class RankWithQueryNode : public QueryConnector { public: - EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } + RankWithQueryNode() noexcept : QueryConnector("RANK") { } + explicit RankWithQueryNode(const char * opName) noexcept : QueryConnector(opName) { } bool evaluate() const override; - bool isFlattenable(ParseItem::ItemType type) const override { - return (type == ParseItem::ITEM_EQUIV); - } }; -/** - N-ary phrase operator. All terms must be satisfied and have the correct order - with distance to next term equal to 1. -*/ -class PhraseQueryNode : public AndQueryNode -{ -public: - PhraseQueryNode() noexcept : AndQueryNode("PHRASE"), _fieldInfo(32) { } - bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; - void getPhrases(QueryNodeRefList & tl) override; - void getPhrases(ConstQueryNodeRefList & tl) const override; - const QueryTerm::FieldInfo & getFieldInfo(size_t fid) const { return _fieldInfo[fid]; } - size_t getFieldInfoSize() const { return _fieldInfo.size(); } - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } - void addChild(QueryNode::UP child) override; -private: - mutable std::vector<QueryTerm::FieldInfo> _fieldInfo; - void updateFieldInfo(size_t fid, size_t offset, size_t fieldLength) const; -#if WE_EVER_NEED_TO_CACHE_THIS_WE_MIGHT_WANT_SOME_CODE_HERE - HitList _cachedHitList; - bool _evaluated; -#endif -}; - -class SameElementQueryNode : public AndQueryNode -{ -public: - SameElementQueryNode() noexcept : AndQueryNode("SAME_ELEMENT") { } - bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } - void addChild(QueryNode::UP child) override; -}; /** - N-ary Near operator. All terms must be within the given distance. -*/ -class NearQueryNode : public AndQueryNode -{ -public: - NearQueryNode() noexcept : AndQueryNode("NEAR"), _distance(0) { } - explicit NearQueryNode(const char * opName) noexcept : AndQueryNode(opName), _distance(0) { } - bool evaluate() const override; - void distance(size_t dist) { _distance = dist; } - size_t distance() const { return _distance; } - void visitMembers(vespalib::ObjectVisitor &visitor) const override; - bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } -private: - size_t _distance; -}; - -/** - N-ary Ordered near operator. The terms must be in order and the distance between - the first and last must not exceed the given distance. + N-ary "EQUIV" operator that merges terms from nodes below. */ -class ONearQueryNode : public NearQueryNode +class EquivQueryNode : public OrQueryNode { public: - ONearQueryNode() noexcept : NearQueryNode("ONEAR") { } + EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } bool evaluate() const override; + bool isFlattenable(ParseItem::ItemType type) const override { + return (type == ParseItem::ITEM_EQUIV); + } }; /** diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 2ee515f062a..32e3ec16b16 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -1,8 +1,12 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-#include "query.h" +#include "fuzzy_term.h" +#include "near_query_node.h" #include "nearest_neighbor_query_node.h" +#include "phrase_query_node.h" +#include "query.h" #include "regexp_term.h" +#include "same_element_query_node.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> @@ -47,6 +51,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_SAME_ELEMENT: case ParseItem::ITEM_NEAR: case ParseItem::ITEM_ONEAR: + case ParseItem::ITEM_RANK: { qn = QueryConnector::create(type); if (qn) { @@ -147,17 +152,16 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } else { Normalizing normalize_mode = factory.normalizing_mode(ssIndex); std::unique_ptr<QueryTerm> qt; - if (sTerm != TermType::REGEXP) { - qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); - } else { + if (sTerm == TermType::REGEXP) { qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode); + } else if (sTerm == TermType::FUZZYTERM) { + qt = std::make_unique<FuzzyTerm>(factory.create(), ssTerm, ssIndex, TermType::FUZZYTERM, normalize_mode, + queryRep.getFuzzyMaxEditDistance(), queryRep.getFuzzyPrefixLength()); + } else [[likely]] { + qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); } qt->setWeight(queryRep.GetWeight()); qt->setUniqueId(queryRep.getUniqueId()); - if (qt->isFuzzy()) { - qt->setFuzzyMaxEditDistance(queryRep.getFuzzyMaxEditDistance()); - qt->setFuzzyPrefixLength(queryRep.getFuzzyPrefixLength()); - } if (allowRewrite && possibleFloat(*qt, ssTerm) && factory.allow_float_terms_rewrite(ssIndex)) { auto phrase = std::make_unique<PhraseQueryNode>(); auto dotPos = ssTerm.find('.'); @@ -173,17 +177,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } } break; - case ParseItem::ITEM_RANK: - { - if (arity >= 1) { - queryRep.next(); - qn = Build(parent, factory, queryRep, false); - for (uint32_t skipCount = arity-1; (skipCount > 0) && queryRep.next(); skipCount--) { - skipCount += queryRep.getArity(); - } - } - } - break; case ParseItem::ITEM_STRING_IN: qn = std::make_unique<InTerm>(factory.create(), queryRep.getIndexName(), queryRep.get_terms(), factory.normalizing_mode(queryRep.getIndexName())); diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp index d72a3371846..af8ce7c9994 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "querynoderesultbase.h" +#include <ostream> namespace search::streaming { diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 3e05d381ee2..e0b78633af3 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -162,9 +162,9 @@ void QueryTerm::resizeFieldId(size_t fieldNo) } } -void QueryTerm::add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight_) +void QueryTerm::add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position) { - _hitList.emplace_back(pos, context, elemId, weight_); + _hitList.emplace_back(field_id, element_id, element_weight, position); } NearestNeighborQueryNode* @@ -185,4 +185,10 @@ QueryTerm::as_regexp_term() noexcept return nullptr; } +FuzzyTerm* +QueryTerm::as_fuzzy_term() noexcept +{ + return nullptr; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index cd2bdd7eaec..627fae0532d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -11,6 +11,7 @@ namespace search::streaming { +class FuzzyTerm; class NearestNeighborQueryNode; class MultiTerm; class RegexpTerm; @@ -64,7 +65,7 @@ public: QueryTerm & operator = (QueryTerm &&) = delete; ~QueryTerm() override; bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; + const HitList & evaluateHits(HitList & hl) const final override; void reset() override; void getLeaves(QueryTermList & tl) override; void getLeaves(ConstQueryTermList & tl) const override; @@ -73,7 +74,7 @@ public: /// Gives you all phrases of this tree. Indicating that they are all const. void getPhrases(ConstQueryNodeRefList & tl) const override; - void add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight); + void add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position); EncodingBitMap encoding() const { return _encoding; } size_t termLen() const { return getTermLen(); } const string & index() const { return _index; } @@ -95,6 +96,7 @@ public: virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept; virtual MultiTerm* as_multi_term() noexcept; virtual RegexpTerm* as_regexp_term() noexcept; + virtual FuzzyTerm* as_fuzzy_term() noexcept; protected: using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>; string _index; diff --git a/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp new file mode 100644 index 00000000000..49d5fb0f9fb --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.cpp @@ -0,0 +1,65 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "same_element_query_node.h" +#include <cassert> + +namespace search::streaming { + +bool +SameElementQueryNode::evaluate() const { + HitList hl; + return ! 
evaluateHits(hl).empty(); +} + +void +SameElementQueryNode::addChild(QueryNode::UP child) { + assert(dynamic_cast<const QueryTerm *>(child.get()) != nullptr); + AndQueryNode::addChild(std::move(child)); +} + +const HitList & +SameElementQueryNode::evaluateHits(HitList & hl) const +{ + hl.clear(); + if ( !AndQueryNode::evaluate()) return hl; + + HitList tmpHL; + const auto & children = getChildren(); + unsigned int numFields = children.size(); + unsigned int currMatchCount = 0; + std::vector<unsigned int> indexVector(numFields, 0); + auto curr = static_cast<const QueryTerm *> (children[currMatchCount].get()); + bool exhausted( curr->evaluateHits(tmpHL).empty()); + for (; !exhausted; ) { + auto next = static_cast<const QueryTerm *>(children[currMatchCount+1].get()); + unsigned int & currIndex = indexVector[currMatchCount]; + unsigned int & nextIndex = indexVector[currMatchCount+1]; + + const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; + uint32_t currElemId = currHit.element_id(); + + const HitList & nextHL = next->evaluateHits(tmpHL); + + size_t nextIndexMax = nextHL.size(); + while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() < currElemId)) { + nextIndex++; + } + if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() == currElemId)) { + currMatchCount++; + if ((currMatchCount+1) == numFields) { + Hit h = nextHL[indexVector[currMatchCount]]; + hl.emplace_back(h.field_id(), h.element_id(), h.element_weight(), 0); + currMatchCount = 0; + indexVector[0]++; + } + } else { + currMatchCount = 0; + indexVector[currMatchCount]++; + } + curr = static_cast<const QueryTerm *>(children[currMatchCount].get()); + exhausted = (nextIndex >= nextIndexMax) || (indexVector[currMatchCount] >= curr->evaluateHits(tmpHL).size()); + } + return hl; +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h new file mode 100644 index 00000000000..50a479eda94 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/same_element_query_node.h @@ -0,0 +1,22 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "query.h" + +namespace search::streaming { + +/** + N-ary Same element operator. All terms must be within the same element. 
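   Editor's aside (not part of the patch): evaluateHits() in the .cpp above
   advances one index per child through the children's hit lists, which are
   assumed sorted by element id. With two term children whose hits fall in
   elements {1,3,5} and {3,5,6} of the same field, the loop emits one merged
   hit for element 3 and one for element 5, each with position 0, since only
   element identity matters for same-element matching.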
+*/ +class SameElementQueryNode : public AndQueryNode +{ +public: + SameElementQueryNode() noexcept : AndQueryNode("SAME_ELEMENT") { } + bool evaluate() const override; + const HitList & evaluateHits(HitList & hl) const override; + bool isFlattenable(ParseItem::ItemType type) const override { return type == ParseItem::ITEM_NOT; } + void addChild(QueryNode::UP child) override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp index 90d0be5d43c..d2d706eef3d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -25,7 +25,7 @@ WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchDat for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()].emplace_back(term->weight().percent()); + scores[hit.field_id()].emplace_back(term->weight().percent()); } } auto num_fields = td.numFields(); diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h index 4473e0fa45b..3d8a5fba843 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -10,7 +10,6 @@ namespace search::streaming { * A weighted set query term for streaming search. */ class WeightedSetTerm : public MultiTerm { - double _score_threshold; public: WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); ~WeightedSetTerm() override; diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index 6ca072d6dc7..2b9b45c990e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -89,7 +89,6 @@ Blueprint::sat_sum(const std::vector<HitEstimate> &data, uint32_t docid_limit) Blueprint::State::State() noexcept : _fields(), - _relative_estimate(0.0), _estimateHits(0), _tree_size(1), _estimateEmpty(true), @@ -106,7 +105,6 @@ Blueprint::State::State(FieldSpecBase field) noexcept Blueprint::State::State(FieldSpecBaseList fields_in) noexcept : _fields(std::move(fields_in)), - _relative_estimate(0.0), _estimateHits(0), _tree_size(1), _estimateEmpty(true), @@ -120,7 +118,9 @@ Blueprint::State::~State() = default; Blueprint::Blueprint() noexcept : _parent(nullptr), - _cost(1.0), + _relative_estimate(0.0), + _cost(0.0), + _strict_cost(0.0), _sourceId(0xffffffff), _docid_limit(0), _frozen(false) @@ -130,15 +130,15 @@ Blueprint::Blueprint() noexcept Blueprint::~Blueprint() = default; Blueprint::UP -Blueprint::optimize(Blueprint::UP bp, bool sort_by_cost) { +Blueprint::optimize(Blueprint::UP bp) { Blueprint *root = bp.release(); - root->optimize(root, OptimizePass::FIRST, sort_by_cost); - root->optimize(root, OptimizePass::LAST, sort_by_cost); + root->optimize(root, OptimizePass::FIRST); + root->optimize(root, OptimizePass::LAST); return Blueprint::UP(root); } void -Blueprint::optimize_self(OptimizePass, bool) +Blueprint::optimize_self(OptimizePass) { } @@ -353,12 +353,13 @@ Blueprint::visitMembers(vespalib::ObjectVisitor &visitor) const visitor.openStruct("estimate", "HitEstimate"); visitor.visitBool("empty", state.estimate().empty); visitor.visitInt("estHits", state.estimate().estHits); - 
visitor.visitFloat("relative_estimate", state.relative_estimate()); visitor.visitInt("cost_tier", state.cost_tier()); visitor.visitInt("tree_size", state.tree_size()); visitor.visitBool("allow_termwise_eval", state.allow_termwise_eval()); visitor.closeStruct(); + visitor.visitFloat("relative_estimate", _relative_estimate); visitor.visitFloat("cost", _cost); + visitor.visitFloat("strict_cost", _strict_cost); visitor.visitInt("sourceId", _sourceId); visitor.visitInt("docid_limit", _docid_limit); } @@ -518,7 +519,6 @@ IntermediateBlueprint::calculateState() const { State state(exposeFields()); state.estimate(calculateEstimate()); - state.relative_estimate(calculate_relative_estimate()); state.cost_tier(calculate_cost_tier()); state.allow_termwise_eval(infer_allow_termwise_eval()); state.want_global_filter(infer_want_global_filter()); @@ -548,25 +548,33 @@ IntermediateBlueprint::should_do_termwise_eval(const UnpackInfo &unpack, double } void -IntermediateBlueprint::optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) +IntermediateBlueprint::optimize(Blueprint* &self, OptimizePass pass) { assert(self == this); - if (should_optimize_children()) { - for (auto &child : _children) { - auto *child_ptr = child.release(); - child_ptr->optimize(child_ptr, pass, sort_by_cost); - child.reset(child_ptr); - } + for (auto &child : _children) { + auto *child_ptr = child.release(); + child_ptr->optimize(child_ptr, pass); + child.reset(child_ptr); } - optimize_self(pass, sort_by_cost); + optimize_self(pass); if (pass == OptimizePass::LAST) { - sort(_children, sort_by_cost); + set_relative_estimate(calculate_relative_estimate()); set_cost(calculate_cost()); + set_strict_cost(calculate_strict_cost()); } maybe_eliminate_self(self, get_replacement()); } void +IntermediateBlueprint::sort(bool strict, bool sort_by_cost) +{ + sort(_children, strict, sort_by_cost); + for (size_t i = 0; i < _children.size(); ++i) { + _children[i]->sort(strict && inheritStrict(i), sort_by_cost); + } +} + +void IntermediateBlueprint::set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) { for (auto & child : _children) { @@ -710,23 +718,32 @@ IntermediateBlueprint::calculateUnpackInfo(const fef::MatchData & md) const //----------------------------------------------------------------------------- -void -LeafBlueprint::setDocIdLimit(uint32_t limit) noexcept { - Blueprint::setDocIdLimit(limit); - _state.relative_estimate(calculate_relative_estimate()); - notifyChange(); -} - double LeafBlueprint::calculate_relative_estimate() const { double rel_est = abs_to_rel_est(_state.estimate().estHits, get_docid_limit()); if (rel_est > 0.9) { // Assume we do not really know how much we are matching when - // we claim to match 'everything' + // we claim to match 'everything'. Also assume we are not able + // to skip documents efficiently when strict. + _can_skip = false; return 0.5; + } else { + _can_skip = true; + return rel_est; } - return rel_est; +} + +double +LeafBlueprint::calculate_cost() const +{ + return 1.0; +} + +double +LeafBlueprint::calculate_strict_cost() const +{ + return _can_skip ? 
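           // Editor's aside (not part of the patch): a leaf that can skip
           // (typically posting-list iteration) only touches matching
           // documents when strict, so its per-document strict cost is the
           // non-strict cost scaled by the relative estimate: a term
           // matching 1% of the corpus with cost 1.0 gets strict cost 0.01.
           // A leaf that claims to match 'everything' has _can_skip set to
           // false by calculate_relative_estimate() above and keeps its
           // full cost when strict.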
estimate() * cost() : cost(); } void @@ -758,14 +775,24 @@ LeafBlueprint::getRange(vespalib::string &, vespalib::string &) const { } void -LeafBlueprint::optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) +LeafBlueprint::optimize(Blueprint* &self, OptimizePass pass) { assert(self == this); - optimize_self(pass, sort_by_cost); + optimize_self(pass); + if (pass == OptimizePass::LAST) { + set_relative_estimate(calculate_relative_estimate()); + set_cost(calculate_cost()); + set_strict_cost(calculate_strict_cost()); + } maybe_eliminate_self(self, get_replacement()); } void +LeafBlueprint::sort(bool, bool) +{ +} + +void LeafBlueprint::set_cost_tier(uint32_t value) { assert(value < 0x100); diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index d998c2e343e..e080d667dfa 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -53,7 +53,7 @@ public: using SearchIteratorUP = std::unique_ptr<SearchIterator>; enum class OptimizePass { FIRST, LAST }; - + struct HitEstimate { uint32_t estHits; bool empty; @@ -75,7 +75,6 @@ public: { private: FieldSpecBaseList _fields; - double _relative_estimate; uint32_t _estimateHits; uint32_t _tree_size : 20; bool _estimateEmpty : 1; @@ -111,9 +110,6 @@ public: return nullptr; } - void relative_estimate(double value) noexcept { _relative_estimate = value; } - double relative_estimate() const noexcept { return _relative_estimate; } - void estimate(HitEstimate est) noexcept { _estimateHits = est.estHits; _estimateEmpty = est.empty; @@ -182,7 +178,9 @@ public: private: Blueprint *_parent; + double _relative_estimate; double _cost; + double _strict_cost; uint32_t _sourceId; uint32_t _docid_limit; bool _frozen; @@ -198,7 +196,9 @@ protected: _frozen = true; } + void set_relative_estimate(double value) noexcept { _relative_estimate = value; } void set_cost(double value) noexcept { _cost = value; } + void set_strict_cost(double value) noexcept { _strict_cost = value; } public: class IPredicate { @@ -222,19 +222,22 @@ public: Blueprint *getParent() const noexcept { return _parent; } bool has_parent() const { return (_parent != nullptr); } - double cost() const noexcept { return _cost; } - Blueprint &setSourceId(uint32_t sourceId) noexcept { _sourceId = sourceId; return *this; } uint32_t getSourceId() const noexcept { return _sourceId; } virtual void setDocIdLimit(uint32_t limit) noexcept { _docid_limit = limit; } uint32_t get_docid_limit() const noexcept { return _docid_limit; } - static Blueprint::UP optimize(Blueprint::UP bp, bool sort_by_cost); - virtual void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) = 0; - virtual void optimize_self(OptimizePass pass, bool sort_by_cost); + static Blueprint::UP optimize(Blueprint::UP bp); + virtual void sort(bool strict, bool sort_by_cost) = 0; + static Blueprint::UP optimize_and_sort(Blueprint::UP bp, bool strict, bool sort_by_cost) { + auto result = optimize(std::move(bp)); + result->sort(strict, sort_by_cost); + return result; + } + virtual void optimize(Blueprint* &self, OptimizePass pass) = 0; + virtual void optimize_self(OptimizePass pass); virtual Blueprint::UP get_replacement(); - virtual bool should_optimize_children() const { return true; } virtual bool supports_termwise_children() const { return false; } virtual bool always_needs_unpack() const { return false; } @@ -254,9 +257,27 @@ public: const Blueprint &root() const; double hit_ratio() const { return 
getState().hit_ratio(_docid_limit); } - double estimate() const { return getState().relative_estimate(); } - virtual double calculate_relative_estimate() const = 0; + // The flow statistics for a blueprint is calculated during the + // LAST optimize pass (just prior to sorting). The relative + // estimate may be used to calculate the costs and the non-strict + // cost may be used to calculate the strict cost. After being + // calculated, each value is available through a simple accessor + // function. Note that these values may not be available for + // blueprints used inside complex leafs (this case will probably + // be solved using custom flow adapters that has knowledge of + // docid limit). + // + // 'estimate': relative estimate in the range [0,1] + // 'cost': per-document cost of non-strict evaluation + // 'strict_cost': per-document cost of strict evaluation + double estimate() const noexcept { return _relative_estimate; } + double cost() const noexcept { return _cost; } + double strict_cost() const noexcept { return _strict_cost; } + virtual double calculate_relative_estimate() const = 0; + virtual double calculate_cost() const = 0; + virtual double calculate_strict_cost() const = 0; + virtual void fetchPostings(const ExecuteInfo &execInfo) = 0; virtual void freeze() = 0; bool frozen() const { return _frozen; } @@ -360,7 +381,8 @@ public: void setDocIdLimit(uint32_t limit) noexcept final; - void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) final; + void optimize(Blueprint* &self, OptimizePass pass) final; + void sort(bool strict, bool sort_by_cost) override; void set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) override; IndexList find(const IPredicate & check) const; @@ -374,10 +396,9 @@ public: Blueprint::UP removeLastChild() { return removeChild(childCnt() - 1); } SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; - virtual double calculate_cost() const = 0; virtual HitEstimate combine(const std::vector<HitEstimate> &data) const = 0; virtual FieldSpecBaseList exposeFields() const = 0; - virtual void sort(Children &children, bool sort_by_cost) const = 0; + virtual void sort(Children &children, bool strict, bool sort_by_cost) const = 0; virtual bool inheritStrict(size_t i) const = 0; virtual SearchIteratorUP createIntermediateSearch(MultiSearch::Children subSearches, @@ -396,11 +417,12 @@ class LeafBlueprint : public Blueprint { private: State _state; + mutable bool _can_skip = true; protected: - void optimize(Blueprint* &self, OptimizePass pass, bool sort_by_cost) final; + void optimize(Blueprint* &self, OptimizePass pass) final; + void sort(bool strict, bool sort_by_cost) override; void setEstimate(HitEstimate est) { _state.estimate(est); - _state.relative_estimate(calculate_relative_estimate()); notifyChange(); } void set_cost_tier(uint32_t value); @@ -431,9 +453,9 @@ protected: public: ~LeafBlueprint() override = default; const State &getState() const final { return _state; } - void setDocIdLimit(uint32_t limit) noexcept final; - using Blueprint::set_cost; double calculate_relative_estimate() const override; + double calculate_cost() const override; + double calculate_strict_cost() const override; void fetchPostings(const ExecuteInfo &execInfo) override; void freeze() final; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h index 
e49fcbcb5bc..c74c2d4e9a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h @@ -27,6 +27,7 @@ protected: public: static constexpr bool filter_search = false; static constexpr bool require_btree_iterators = true; + static constexpr bool supports_hash_filter = false; // TODO: use MultiSearch::Children to pass ownership static SearchIterator::UP create(const std::vector<SearchIterator*> &children, diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h index 86ce6f8b93b..b90321581b5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow.h @@ -13,19 +13,11 @@ namespace search::queryeval { namespace flow { // the default adapter expects the shape of std::unique_ptr<Blueprint> -// with respect to estimate, cost and (coming soon) strict_cost. +// with respect to estimate, cost and strict_cost. struct DefaultAdapter { double estimate(const auto &child) const noexcept { return child->estimate(); } double cost(const auto &child) const noexcept { return child->cost(); } - // Estimate the per-document cost of strict evaluation of this - // child. This will typically be something like (estimate() * - // cost()) for leafs with posting lists. OR will aggregate strict - // cost by calculating the minimal OR flow of strict child - // costs. AND will aggregate strict cost by calculating the - // minimal AND flow where the cost of the first child is - // substituted by its strict cost. This value is currently not - // available in Blueprints. - double strict_cost(const auto &child) const noexcept { return child->cost(); } + double strict_cost(const auto &child) const noexcept { return child->strict_cost(); } }; template <typename ADAPTER, typename T> @@ -154,8 +146,6 @@ struct FlowMixin { static double cost_of(const auto &children, bool strict) { return cost_of(flow::DefaultAdapter(), children, strict); } - // TODO: remove - static double cost_of(const auto &children) { return cost_of(children, false); } }; class AndFlow : public FlowMixin<AndFlow> { @@ -190,9 +180,8 @@ public: children[0] = std::move(the_one); } } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; @@ -224,9 +213,8 @@ public: flow::sort<flow::MinOrCost>(adapter, children); } } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; @@ -254,9 +242,8 @@ public: static void sort(auto adapter, auto &children, bool) { flow::sort_partial<flow::MinOrCost>(adapter, children, 1); } - // TODO: add strict - static void sort(auto &children) { - sort(flow::DefaultAdapter(), children, false); + static void sort(auto &children, bool strict) { + sort(flow::DefaultAdapter(), children, strict); } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp index e60fe3d3f85..8cabe189b0e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp @@ -33,7 +33,7 @@ size_t lookup_create_source(std::vector<std::unique_ptr<CombineType> > &sources, } template <typename 
CombineType> -void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx, bool sort_by_cost) { +void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx) { std::vector<size_t> source_blenders; const SourceBlenderBlueprint * reference = nullptr; for (size_t i = begin_idx; i < self.childCnt(); ++i) { @@ -63,7 +63,7 @@ void optimize_source_blenders(IntermediateBlueprint &self, size_t begin_idx, boo top->addChild(std::move(sources.back())); sources.pop_back(); } - blender_up = Blueprint::optimize(std::move(blender_up), sort_by_cost); + blender_up = Blueprint::optimize(std::move(blender_up)); self.addChild(std::move(blender_up)); } } @@ -87,15 +87,21 @@ need_normal_features_for_children(const IntermediateBlueprint &blueprint, fef::M //----------------------------------------------------------------------------- double +AndNotBlueprint::calculate_relative_estimate() const +{ + return AndNotFlow::estimate_of(get_children()); +} + +double AndNotBlueprint::calculate_cost() const { - return AndNotFlow::cost_of(get_children()); + return AndNotFlow::cost_of(get_children(), false); } double -AndNotBlueprint::calculate_relative_estimate() const +AndNotBlueprint::calculate_strict_cost() const { - return AndNotFlow::estimate_of(get_children()); + return AndNotFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -114,7 +120,7 @@ AndNotBlueprint::exposeFields() const } void -AndNotBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +AndNotBlueprint::optimize_self(OptimizePass pass) { if (childCnt() == 0) { return; @@ -152,7 +158,7 @@ AndNotBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 1, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 1); } } @@ -166,10 +172,10 @@ AndNotBlueprint::get_replacement() } void -AndNotBlueprint::sort(Children &children, bool sort_by_cost) const +AndNotBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndNotFlow::sort(children); + AndNotFlow::sort(children, strict); } else { if (children.size() > 2) { std::sort(children.begin() + 1, children.end(), TieredGreaterEstimate()); @@ -213,13 +219,18 @@ AndNotBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) co //----------------------------------------------------------------------------- double +AndBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double AndBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()); + return AndFlow::cost_of(get_children(), false); } double -AndBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +AndBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -235,7 +246,7 @@ AndBlueprint::exposeFields() const } void -AndBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +AndBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 0; i < childCnt(); ++i) { @@ -248,7 +259,7 @@ AndBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<AndBlueprint>(*this, 0, sort_by_cost); + optimize_source_blenders<AndBlueprint>(*this, 0); } } @@ -262,10 +273,10 @@ AndBlueprint::get_replacement() } void -AndBlueprint::sort(Children &children, bool sort_by_cost) const 
+AndBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndFlow::sort(children); + AndFlow::sort(children, strict); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -322,13 +333,18 @@ OrBlueprint::computeNextHitRate(const Blueprint & child, double hit_rate) const OrBlueprint::~OrBlueprint() = default; double +OrBlueprint::calculate_relative_estimate() const { + return OrFlow::estimate_of(get_children()); +} + +double OrBlueprint::calculate_cost() const { - return OrFlow::cost_of(get_children()); + return OrFlow::cost_of(get_children(), false); } double -OrBlueprint::calculate_relative_estimate() const { - return OrFlow::estimate_of(get_children()); +OrBlueprint::calculate_strict_cost() const { + return OrFlow::cost_of(get_children(), true); } Blueprint::HitEstimate @@ -344,7 +360,7 @@ OrBlueprint::exposeFields() const } void -OrBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +OrBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 0; (childCnt() > 1) && (i < childCnt()); ++i) { @@ -359,7 +375,7 @@ OrBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 0, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 0); } } @@ -373,10 +389,10 @@ OrBlueprint::get_replacement() } void -OrBlueprint::sort(Children &children, bool sort_by_cost) const +OrBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - OrFlow::sort(children); + OrFlow::sort(children, strict); } else { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -427,17 +443,22 @@ OrBlueprint::calculate_cost_tier() const WeakAndBlueprint::~WeakAndBlueprint() = default; double -WeakAndBlueprint::calculate_cost() const { - return OrFlow::cost_of(get_children()); -} - -double WeakAndBlueprint::calculate_relative_estimate() const { double child_est = OrFlow::estimate_of(get_children()); double my_est = abs_to_rel_est(_n, get_docid_limit()); return std::min(my_est, child_est); } +double +WeakAndBlueprint::calculate_cost() const { + return OrFlow::cost_of(get_children(), false); +} + +double +WeakAndBlueprint::calculate_strict_cost() const { + return OrFlow::cost_of(get_children(), true); +} + Blueprint::HitEstimate WeakAndBlueprint::combine(const std::vector<HitEstimate> &data) const { @@ -456,7 +477,7 @@ WeakAndBlueprint::exposeFields() const } void -WeakAndBlueprint::sort(Children &, bool) const +WeakAndBlueprint::sort(Children &, bool, bool) const { // order needs to stay the same as _weights } @@ -498,13 +519,18 @@ WeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) c //----------------------------------------------------------------------------- double +NearBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double NearBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()) + childCnt() * 1.0; + return AndFlow::cost_of(get_children(), false) + childCnt() * estimate(); } double -NearBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +NearBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true) + childCnt() * estimate(); } Blueprint::HitEstimate @@ -520,10 +546,10 @@ NearBlueprint::exposeFields() const } void -NearBlueprint::sort(Children &children, bool sort_by_cost) 
const +NearBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const { if (sort_by_cost) { - AndFlow::sort(children); + AndFlow::sort(children, strict); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -565,13 +591,18 @@ NearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons //----------------------------------------------------------------------------- double +ONearBlueprint::calculate_relative_estimate() const { + return AndFlow::estimate_of(get_children()); +} + +double ONearBlueprint::calculate_cost() const { - return AndFlow::cost_of(get_children()) + (childCnt() * 1.0); + return AndFlow::cost_of(get_children(), false) + childCnt() * estimate(); } double -ONearBlueprint::calculate_relative_estimate() const { - return AndFlow::estimate_of(get_children()); +ONearBlueprint::calculate_strict_cost() const { + return AndFlow::cost_of(get_children(), true) + childCnt() * estimate(); } Blueprint::HitEstimate @@ -587,7 +618,7 @@ ONearBlueprint::exposeFields() const } void -ONearBlueprint::sort(Children &, bool) const +ONearBlueprint::sort(Children &, bool, bool) const { // ordered near cannot sort children here } @@ -630,13 +661,18 @@ ONearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) con //----------------------------------------------------------------------------- double +RankBlueprint::calculate_relative_estimate() const { + return (childCnt() == 0) ? 0.0 : getChild(0).estimate(); +} + +double RankBlueprint::calculate_cost() const { - return (childCnt() == 0) ? 1.0 : getChild(0).cost(); + return (childCnt() == 0) ? 0.0 : getChild(0).cost(); } double -RankBlueprint::calculate_relative_estimate() const { - return (childCnt() == 0) ? 0.0 : getChild(0).estimate(); +RankBlueprint::calculate_strict_cost() const { + return (childCnt() == 0) ? 
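        // Editor's aside (not part of the patch): RANK matches exactly when
        // its first child matches; the remaining children only contribute
        // ranking signals and never drive iteration. All three flow values
        // (relative estimate, cost, strict cost) are therefore forwarded
        // unchanged from child 0.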
0.0 : getChild(0).strict_cost(); } Blueprint::HitEstimate @@ -655,7 +691,7 @@ RankBlueprint::exposeFields() const } void -RankBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) +RankBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::FIRST) { for (size_t i = 1; i < childCnt(); ++i) { @@ -665,7 +701,7 @@ RankBlueprint::optimize_self(OptimizePass pass, bool sort_by_cost) } } if (pass == OptimizePass::LAST) { - optimize_source_blenders<OrBlueprint>(*this, 1, sort_by_cost); + optimize_source_blenders<OrBlueprint>(*this, 1); } } @@ -679,7 +715,7 @@ RankBlueprint::get_replacement() } void -RankBlueprint::sort(Children &, bool) const +RankBlueprint::sort(Children &, bool, bool) const { } @@ -731,8 +767,13 @@ SourceBlenderBlueprint::SourceBlenderBlueprint(const ISourceSelector &selector) SourceBlenderBlueprint::~SourceBlenderBlueprint() = default; double +SourceBlenderBlueprint::calculate_relative_estimate() const { + return OrFlow::estimate_of(get_children()); +} + +double SourceBlenderBlueprint::calculate_cost() const { - double my_cost = 1.0; + double my_cost = 0.0; for (const auto &child: get_children()) { my_cost = std::max(my_cost, child->cost()); } @@ -740,8 +781,12 @@ SourceBlenderBlueprint::calculate_cost() const { } double -SourceBlenderBlueprint::calculate_relative_estimate() const { - return OrFlow::estimate_of(get_children()); +SourceBlenderBlueprint::calculate_strict_cost() const { + double my_cost = 0.0; + for (const auto &child: get_children()) { + my_cost = std::max(my_cost, child->strict_cost()); + } + return my_cost; } Blueprint::HitEstimate @@ -757,7 +802,7 @@ SourceBlenderBlueprint::exposeFields() const } void -SourceBlenderBlueprint::sort(Children &, bool) const +SourceBlenderBlueprint::sort(Children &, bool, bool) const { } diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h index 620280e979b..368cbd35c69 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h @@ -15,14 +15,15 @@ class AndNotBlueprint : public IntermediateBlueprint { public: bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; AndNotBlueprint * asAndNot() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -43,14 +44,15 @@ class AndBlueprint : public IntermediateBlueprint { public: bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, 
bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; AndBlueprint * asAnd() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -69,14 +71,15 @@ class OrBlueprint : public IntermediateBlueprint public: ~OrBlueprint() override; bool supports_termwise_children() const override { return true; } - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; OrBlueprint * asOr() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -97,11 +100,12 @@ private: std::vector<uint32_t> _weights; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_on_cost) const override; bool inheritStrict(size_t i) const override; bool always_needs_unpack() const override; WeakAndBlueprint * asWeakAnd() noexcept final { return this; } @@ -128,12 +132,12 @@ private: uint32_t _window; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - bool should_optimize_children() const override { return false; } - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; SearchIterator::UP @@ -152,12 +156,12 @@ private: uint32_t _window; public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - bool should_optimize_children() const override { return false; } - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIteratorUP createSearch(fef::MatchData &md, bool strict) const override; SearchIterator::UP @@ -173,13 +177,14 @@ public: class RankBlueprint final : public 
IntermediateBlueprint { public: - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; Blueprint::UP get_replacement() override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; bool isRank() const noexcept final { return true; } SearchIterator::UP @@ -202,11 +207,12 @@ private: public: explicit SourceBlenderBlueprint(const ISourceSelector &selector) noexcept; ~SourceBlenderBlueprint() override; - double calculate_cost() const final; double calculate_relative_estimate() const final; + double calculate_cost() const final; + double calculate_strict_cost() const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool sort_by_cost) const override; + void sort(Children &children, bool strict, bool sort_by_cost) const override; bool inheritStrict(size_t i) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, diff --git a/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp b/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp index 1f83075b9fc..8fc7733f279 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearsearch.cpp @@ -3,7 +3,7 @@ #include <vespa/vespalib/objects/visit.h> #include <vespa/vespalib/util/priority_queue.h> #include <limits> -#include <set> +#include <map> #include <vespa/log/log.h> LOG_SETUP(".nearsearch"); @@ -16,13 +16,15 @@ using search::fef::TermFieldMatchDataArray; using search::fef::TermFieldMatchDataPositionKey; template<typename T> -void setup_fields(uint32_t window, std::vector<T> &matchers, const TermFieldMatchDataArray &in) { - std::set<uint32_t> fields; +void setup_fields(uint32_t window, std::vector<T> &matchers, const TermFieldMatchDataArray &in, uint32_t terms) { + std::map<uint32_t,uint32_t> fields; for (size_t i = 0; i < in.size(); ++i) { - fields.insert(in[i]->getFieldId()); + ++fields[in[i]->getFieldId()]; } - for (const auto& elem : fields) { - matchers.push_back(T(window, elem, in)); + for (auto [field, cnt]: fields) { + if (cnt == terms) { + matchers.push_back(T(window, field, in)); + } } } @@ -126,7 +128,7 @@ NearSearch::NearSearch(Children terms, : NearSearchBase(std::move(terms), data, window, strict), _matchers() { - setup_fields(window, _matchers, data); + setup_fields(window, _matchers, data, getChildren().size()); } namespace { @@ -227,7 +229,7 @@ ONearSearch::ONearSearch(Children terms, : NearSearchBase(std::move(terms), data, window, strict), _matchers() { - setup_fields(window, _matchers, data); + setup_fields(window, _matchers, data, getChildren().size()); } bool diff --git a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp index 500e9fe4dbb..96181377282 100644 --- a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp @@ -45,7 +45,7 @@ 
SameElementBlueprint::addTerm(Blueprint::UP term) } void -SameElementBlueprint::optimize_self(OptimizePass pass, bool) +SameElementBlueprint::optimize_self(OptimizePass pass) { if (pass == OptimizePass::LAST) { std::sort(_terms.begin(), _terms.end(), diff --git a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h index 06c20339e81..6a988e67149 100644 --- a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.h @@ -34,7 +34,7 @@ public: // used by create visitor void addTerm(Blueprint::UP term); - void optimize_self(OptimizePass pass, bool sort_by_cost) override; + void optimize_self(OptimizePass pass) override; void fetchPostings(const ExecuteInfo &execInfo) override; std::unique_ptr<SameElementSearch> create_same_element_search(search::fef::TermFieldMatchData& tfmd, bool strict) const; diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index 0c57b21aba6..0be014474d0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "weighted_set_term_search.h" +#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/attribute/i_direct_posting_store.h> +#include <vespa/searchlib/attribute/multi_term_hash_filter.hpp> #include <vespa/searchlib/common/bitvector.h> -#include <vespa/searchlib/attribute/multi_term_or_filter_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/vespalib/stllike/hash_map.hpp> #include "iterator_pack.h" #include "blueprint.h" @@ -237,4 +240,98 @@ WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, return create_helper_resolve_pack<DocidWithWeightIterator, DocidWithWeightIteratorPack>(tmd, is_filter_search, std::move(weights), std::move(iterators)); } +namespace { + +class HashFilterWrapper { +protected: + const attribute::IAttributeVector& _attr; +public: + HashFilterWrapper(const attribute::IAttributeVector& attr) : _attr(attr) {} +}; + +template <bool unpack_weights_t> +class StringHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::EnumHandle; + static constexpr bool unpack_weights = unpack_weights_t; + StringHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, const IDirectPostingStore& store, vespalib::datastore::EntryRef dict_snapshot) const { + std::vector<TokenT> result; + store.collect_folded(term.enum_idx, dict_snapshot, [&](vespalib::datastore::EntryRef ref) { result.emplace_back(ref.ref()); }); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getEnum(docid); + } +}; + +template <bool unpack_weights_t> +class IntegerHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::largeint_t; + static constexpr bool unpack_weights = unpack_weights_t; + IntegerHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const 
IDirectPostingStore::LookupResult& term, + const IDirectPostingStore& store, + vespalib::datastore::EntryRef) const { + std::vector<TokenT> result; + result.emplace_back(store.get_integer_value(term.enum_idx)); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getInt(docid); + } +}; + +template <typename WrapperType> +SearchIterator::UP +create_hash_filter_helper(fef::TermFieldMatchData& tfmd, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + using FilterType = attribute::MultiTermHashFilter<WrapperType>; + typename FilterType::TokenMap tokens; + WrapperType wrapper(attr); + for (size_t i = 0; i < terms.size(); ++i) { + for (auto token : wrapper.mapToken(terms[i], posting_store, dict_snapshot)) { + tokens[token] = weights[i]; + } + } + return std::make_unique<FilterType>(tfmd, wrapper, std::move(tokens)); +} + +} + +SearchIterator::UP +WeightedSetTermSearch::create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + if (attr.isStringType()) { + if (is_filter_search) { + return create_hash_filter_helper<StringHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<StringHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } else { + assert(attr.isIntegerType()); + if (is_filter_search) { + return create_hash_filter_helper<IntegerHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<IntegerHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h index 7e47928fb90..d078fd5babc 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h @@ -11,6 +11,8 @@ #include <variant> #include <vector> +namespace search::attribute { class IAttributeVector; } + namespace search::fef { class TermFieldMatchData; } namespace search::queryeval { @@ -30,6 +32,8 @@ public: static constexpr bool filter_search = false; // Whether this iterator requires btree iterators for all tokens/terms used by the operator. 
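    // Editor's aside (not part of the patch): create_hash_filter() in the
    // .cpp hunk above builds a token -> weight hash map once, up front:
    // string attributes fold each term's enum index into one or more
    // enum-handle tokens via collect_folded(), while integer attributes use
    // the integer value itself as the token. During matching, the wrapper's
    // getToken(docid) result is looked up in that map, so each document
    // costs one hash probe instead of one posting iterator per term. When
    // two terms fold to the same token, the later weight overwrites the
    // earlier one (tokens[token] = weights[i]).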
static constexpr bool require_btree_iterators = false; + // Whether this supports creating a hash filter iterator; + static constexpr bool supports_hash_filter = true; // TODO: pass ownership with unique_ptr static SearchIterator::UP create(const std::vector<SearchIterator *> &children, @@ -48,6 +52,14 @@ public: std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, std::vector<DocidWithWeightIterator> &&iterators); + static SearchIterator::UP create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dictionary_snapshot); + // used during docsum fetching to identify matching elements // initRange must be called before use. // doSeek/doUnpack must not be called. diff --git a/storage/src/vespa/storage/config/stor-server.def b/storage/src/vespa/storage/config/stor-server.def index 0d877d33277..3d304dd1727 100644 --- a/storage/src/vespa/storage/config/stor-server.def +++ b/storage/src/vespa/storage/config/stor-server.def @@ -56,7 +56,7 @@ merge_throttling_policy.window_size_increment double default=2.0 ## > 0 explicit limit in bytes ## == 0 limit automatically deduced by content node ## < 0 unlimited (legacy behavior) -merge_throttling_memory_limit.max_usage_bytes long default=-1 +merge_throttling_memory_limit.max_usage_bytes long default=0 ## If merge_throttling_memory_limit.max_usage_bytes == 0, this factor is used ## as a multiplier to automatically deduce a memory limit for merges on the diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 93e35e4c6d2..c9518b29884 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -84,7 +84,7 @@ RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_feature EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - node->add(0, field_id, 0, 1); + node->add(field_id, 0, 1, 0); auto& field_info = node->getFieldInfo(field_id); field_info.setHitCount(mock_num_occs); field_info.setFieldLength(mock_field_length); diff --git a/streamingvisitors/src/tests/searcher/searcher_test.cpp b/streamingvisitors/src/tests/searcher/searcher_test.cpp index 24877866c1b..eb233db9632 100644 --- a/streamingvisitors/src/tests/searcher/searcher_test.cpp +++ b/streamingvisitors/src/tests/searcher/searcher_test.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/document/fieldvalue/fieldvalues.h> +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <vespa/searchlib/query/streaming/queryterm.h> #include <vespa/vsm/searcher/boolfieldsearcher.h> @@ -18,10 +19,14 @@ #include <vespa/vsm/searcher/utf8suffixstringfieldsearcher.h> #include <vespa/vsm/searcher/tokenizereader.h> #include <vespa/vsm/vsm/snippetmodifier.h> +#include <concepts> +#include <charconv> +#include <stdexcept> using namespace document; using search::streaming::HitList; using search::streaming::QueryNodeResultFactory; +using search::streaming::FuzzyTerm; using search::streaming::RegexpTerm; using search::streaming::QueryTerm; using 
search::streaming::Normalizing; @@ -58,6 +63,46 @@ public: } }; +namespace { + +template <std::integral T> +std::string_view maybe_consume_into(std::string_view str, T& val_out) { + auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), val_out); + if (ec != std::errc()) { + return str; + } + return str.substr(ptr - str.data()); +} + +// Parse optional max edits and prefix lock length from term string. +// Syntax: +// "term" -> {2, 0, "term"} (default max edits & prefix length) +// "{1}term" -> {1, 0, "term"} +// "{1,3}term" -> {1, 3, "term"} +// +// Note: this is not a "proper" parser (it accepts empty numeric values); only for testing! +std::tuple<uint8_t, uint32_t, std::string_view> parse_fuzzy_params(std::string_view term) { + if (term.empty() || term[0] != '{') { + return {2, 0, term}; + } + uint8_t max_edits = 2; + uint32_t prefix_length = 0; + term = maybe_consume_into(term.substr(1), max_edits); + if (term.empty() || (term[0] != ',' && term[0] != '}')) { + throw std::invalid_argument("malformed fuzzy params at (or after) max_edits"); + } + if (term[0] == '}') { + return {max_edits, prefix_length, term.substr(1)}; + } + term = maybe_consume_into(term.substr(1), prefix_length); + if (term.empty() || term[0] != '}') { + throw std::invalid_argument("malformed fuzzy params at (or after) prefix_length"); + } + return {max_edits, prefix_length, term.substr(1)}; +} + +} + class Query { private: @@ -66,10 +111,14 @@ private: ParsedQueryTerm pqt = parseQueryTerm(term); ParsedTerm pt = parseTerm(pqt.second); std::string effective_index = pqt.first.empty() ? "index" : pqt.first; - if (pt.second != TermType::REGEXP) { - qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + if (pt.second == TermType::REGEXP) { + qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, TermType::REGEXP, normalizing)); + } else if (pt.second == TermType::FUZZYTERM) { + auto [max_edits, prefix_length, actual_term] = parse_fuzzy_params(pt.first); + qtv.push_back(std::make_unique<FuzzyTerm>(eqnr.create(), vespalib::stringref(actual_term.data(), actual_term.size()), + effective_index, TermType::FUZZYTERM, normalizing, max_edits, prefix_length)); } else { - qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); } } for (const auto & i : qtv) { @@ -100,6 +149,8 @@ public: return std::make_pair(term.substr(1, term.size() - 1), TermType::SUFFIXTERM); } else if (term[0] == '#') { // magic regex enabler return std::make_pair(term.substr(1), TermType::REGEXP); + } else if (term[0] == '%') { // equally magic fuzzy enabler + return std::make_pair(term.substr(1), TermType::FUZZYTERM); } else if (term[term.size() - 1] == '*') { return std::make_pair(term.substr(0, term.size() - 1), TermType::PREFIXTERM); } else { @@ -349,7 +400,8 @@ assertSearch(FieldSearcher & fs, const StringList & query, const FieldValue & fv EXPECT_EQUAL(hl.size(), exp[i].size()); ASSERT_TRUE(hl.size() == exp[i].size()); for (size_t j = 0; j < hl.size(); ++j) { - EXPECT_EQUAL((size_t)hl[j].pos(), exp[i][j]); + EXPECT_EQUAL(0u, hl[j].field_id()); + EXPECT_EQUAL((size_t)hl[j].position(), exp[i][j]); } } } @@ -477,31 +529,54 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) return true; } - TEST("verify correct term parsing") { - ASSERT_TRUE(Query::parseQueryTerm("index:term").first == 
"index"); - ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); - ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); - ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); - ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); - ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); - ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); - ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); - ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); - ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); - ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); - ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); - ASSERT_TRUE(Query::parseTerm("term").first == "term"); - ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); - } - - TEST("suffix matching") { - EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); - } +TEST("parsing of test-only fuzzy term params can extract numeric values") { + uint8_t max_edits = 0; + uint32_t prefix_length = 1234; + std::string_view out; + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{3}myterm"); + EXPECT_EQUAL(max_edits, 3u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{2,70}myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 70u); + EXPECT_EQUAL(out, "myterm"); +} + +TEST("verify correct term parsing") { + ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); + ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); + ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); + ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); + ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); + ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); + ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); + ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); + ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); + ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); + ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); + ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); + ASSERT_TRUE(Query::parseTerm("%fuzzy").first == "fuzzy"); + ASSERT_TRUE(Query::parseTerm("%fuzzy").second == TermType::FUZZYTERM); + ASSERT_TRUE(Query::parseTerm("term").first == "term"); + ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); +} + +TEST("suffix matching") { + EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); +} TEST("Test basic strchrfield searchers") { { @@ -654,6 +729,112 @@ 
TEST("utf8 flexible searcher handles regexes with explicit anchoring") { TEST_DO(assertString(fs, "#^foo$", "oo", Hits())); } +TEST("utf8 flexible searcher regex matching treats field as 1 word") { + UTF8FlexibleStringFieldSearcher fs(0); + // Match case + TEST_DO(assertFieldInfo(fs, "#.*", "foo bar baz", QTFieldInfo(0, 1, 1))); + // Mismatch case + TEST_DO(assertFieldInfo(fs, "#^zoid$", "foo bar baz", QTFieldInfo(0, 0, 1))); +} + +TEST("utf8 flexible searcher handles fuzzy search in uncased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + // Term syntax (only applies to these tests): + // %{k}term => fuzzy match "term" with max edits k + // %{k,p}term => fuzzy match "term" with max edits k, prefix lock length p + + // DFA is used for k in {1, 2} + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}Abc", "abd", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABCD", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "abcde", Hits())); + TEST_DO(assertString(fs, "%{2}abc", "abcde", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "xabcde", Hits())); + // Fallback to non-DFA matcher when k not in {1, 2} + TEST_DO(assertString(fs, "%{3}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ!", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search in cased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "Abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits())); + TEST_DO(assertString(fs, "%{2}Abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "AbC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABCD", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search with prefix locking") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBerg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBergg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidborg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidblergh", Hits())); + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidblergh", Hits().add(0))); + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidblergh", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidbooorg", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidzooorg", Hits())); + + fs.normalize_mode(Normalizing::NONE); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{1,4}ZOID", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBerg", Hits().add(0))); // 1 edit + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBblerg", Hits())); // 2 edits, 1 max + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidBblerg", Hits().add(0))); // 2 edits, 2 max + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidBERG", 
Hits())); // 4 edits, 3 max + TEST_DO(assertString(fs, "%{4,4}zoidberg", "zoidBERG", Hits().add(0))); // 4 edits, 4 max +} + +TEST("utf8 flexible searcher fuzzy match with max_edits=0 implies exact match") { + UTF8FlexibleStringFieldSearcher fs(0); + TEST_DO(assertString(fs, "%{0}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits().add(0))); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoid", Hits().add(0))); +} + +TEST("utf8 flexible searcher caps oversized fuzzy prefix length to term length") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "boid", Hits())); + // Fallback + TEST_DO(assertString(fs, "%{0,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "boid", Hits())); +} + +TEST("utf8 flexible searcher fuzzy matching treats field as 1 word") { + UTF8FlexibleStringFieldSearcher fs(0); + // Match case + TEST_DO(assertFieldInfo(fs, "%{1}foo bar baz", "foo jar baz", QTFieldInfo(0, 1, 1))); + // Mismatch case + TEST_DO(assertFieldInfo(fs, "%{1}foo", "foo bar baz", QTFieldInfo(0, 0, 1))); +} + TEST("bool search") { BoolFieldSearcher fs(0); TEST_DO(assertBool(fs, "true", true, true));
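The %{max_edits,prefix_length} syntax exercised in the fuzzy tests above is test-only sugar: parseTerm() strips the leading '%' magic character, then parse_fuzzy_params (the helper defined in this test file's anonymous namespace) decomposes the optional {...} group. A minimal, hedged C++ sketch of that contract, usable as a reference while reading the tests (it assumes the helper above and is not part of the patch):

#include <cassert>
#include <string_view>

void fuzzy_params_example() {
    // "%{2,4}zoidberg": parseTerm() has already removed '%', leaving "{2,4}zoidberg"
    auto [max_edits, prefix_length, term] = parse_fuzzy_params("{2,4}zoidberg");
    assert(max_edits == 2);      // allow up to 2 Levenshtein edits
    assert(prefix_length == 4);  // the first 4 characters are locked and must match exactly
    assert(term == "zoidberg");
    // Without a '{...}' group the documented defaults {2, 0, term} apply:
    auto [k, p, t] = parse_fuzzy_params("zoidberg");
    assert(k == 2 && p == 0 && t == "zoidberg");
    // Malformed groups (e.g. "{2,4zoidberg") throw std::invalid_argument instead.
}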
diff --git a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp index 095141c0359..79bacda3f3b 100644 --- a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp @@ -3,6 +3,7 @@ #include "matching_elements_filler.h" #include <vespa/searchlib/common/matching_elements.h> #include <vespa/searchlib/common/matching_elements_fields.h> +#include <vespa/searchlib/query/streaming/same_element_query_node.h> #include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/vsm/searcher/fieldsearcher.h> #include <vespa/vdslib/container/searchresult.h> @@ -109,7 +110,7 @@ Matcher::add_matching_elements(const vespalib::string& field_name, uint32_t doc_ { _elements.clear(); for (auto& hit : hit_list) { - _elements.emplace_back(hit.elemId()); + _elements.emplace_back(hit.element_id()); } if (_elements.size() > 1) { std::sort(_elements.begin(), _elements.end()); diff --git a/streamingvisitors/src/vespa/searchvisitor/querywrapper.h b/streamingvisitors/src/vespa/searchvisitor/querywrapper.h index 27b5a2e12d4..b24f695196e 100644 --- a/streamingvisitors/src/vespa/searchvisitor/querywrapper.h +++ b/streamingvisitors/src/vespa/searchvisitor/querywrapper.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/searchlib/query/streaming/phrase_query_node.h> #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/querynode.h> diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 6b15b7cb88e..0a64ee7c093 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -114,7 +114,7 @@ RankProcessor::initQueryEnvironment() void RankProcessor::initHitCollector(size_t wantedHitCount) { - _hitCollector.reset(new HitCollector(wantedHitCount)); + _hitCollector = std::make_unique<HitCollector>(wantedHitCount); } void @@ -209,9 +209,8 @@ class RankProgramWrapper : public HitCollector::IRankProgram { private: MatchData &_match_data; - public: - RankProgramWrapper(MatchData &match_data) : _match_data(match_data) {} + explicit RankProgramWrapper(MatchData &match_data) : _match_data(match_data) {} void run(uint32_t docid, const std::vector<search::fef::TermFieldMatchData> &matchData) override { // Prepare the match data object used by the rank program with earlier unpacked match data. copyTermFieldMatchData(matchData, _match_data); @@ -300,7 +299,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { - uint32_t fieldId = hit.context(); + uint32_t fieldId = hit.field_id(); if (fieldId != lastFieldId) { // reset to notfound/unknown values tmd = nullptr; @@ -335,8 +334,8 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap } if (tmd != nullptr) { // adjust so that the position for phrase terms equals the match for the first term - TermFieldMatchDataPosition pos(hit.elemId(), hit.wordpos() - term.getPosAdjust(), - hit.weight(), fieldLen); + TermFieldMatchDataPosition pos(hit.element_id(), hit.position() - term.getPosAdjust(), + hit.element_weight(), fieldLen); tmd->appendPosition(pos); LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)", pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight()); diff --git a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp index cdd1a018d84..979e5f25b6a 100644 --- a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp @@ -391,7 +391,17 @@ SearchVisitor::init(const Parameters & params) } _queryResult->getSearchResult().setWantedHitCount(wantedSummaryCount); + vespalib::stringref sortRef; + bool hasSortSpec = params.lookup("sort", sortRef); + vespalib::stringref groupingRef; + bool hasGrouping = params.lookup("aggregation", groupingRef); + if (params.lookup("rankprofile", valueRef) ) { + if ( ! 
hasGrouping && (wantedSummaryCount == 0)) { + // If no hits and no grouping, just use unranked profile + // TODO: could optionally also check whether grouping needs rank + valueRef = "unranked"; + } vespalib::string tmp(valueRef.data(), valueRef.size()); _rankController.setRankProfile(tmp); LOG(debug, "Received rank profile: %s", _rankController.getRankProfile().c_str()); @@ -442,9 +452,9 @@ SearchVisitor::init(const Parameters & params) if (_env) { _init_called = true; - if ( params.lookup("sort", valueRef) ) { + if ( hasSortSpec ) { search::uca::UcaConverterFactory ucaFactory; - _sortSpec = search::common::SortSpec(vespalib::string(valueRef.data(), valueRef.size()), ucaFactory); + _sortSpec = search::common::SortSpec(vespalib::string(sortRef.data(), sortRef.size()), ucaFactory); LOG(debug, "Received sort specification: '%s'", _sortSpec.getSpec().c_str()); } @@ -494,10 +504,10 @@ SearchVisitor::init(const Parameters & params) LOG(warning, "No query received"); } - if (params.lookup("aggregation", valueRef) ) { + if (hasGrouping) { std::vector<char> newAggrBlob; - newAggrBlob.resize(valueRef.size()); - memcpy(&newAggrBlob[0], valueRef.data(), newAggrBlob.size()); + newAggrBlob.resize(groupingRef.size()); + memcpy(&newAggrBlob[0], groupingRef.data(), newAggrBlob.size()); LOG(debug, "Received new aggregation blob of %zd bytes", newAggrBlob.size()); setupGrouping(newAggrBlob); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h index c5bca6f3899..e339e4bdf5a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h @@ -46,7 +46,7 @@ public: explicit FieldSearcher(FieldIdT fId) noexcept : FieldSearcher(fId, false) {} FieldSearcher(FieldIdT fId, bool defaultPrefix) noexcept; ~FieldSearcher() override; - virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; + [[nodiscard]] virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; bool search(const StorageDocument & doc); virtual void prepare(search::streaming::QueryTermList& qtl, const SharedSearcherBuf& buf, const vsm::FieldPathMapT& field_paths, search::fef::IQueryEnvironment& query_env); @@ -106,7 +106,7 @@ protected: * For each call to onValue() a batch of words is processed, and the position is local to this batch. 
**/ void addHit(search::streaming::QueryTerm & qt, uint32_t pos) const { - qt.add(_words + pos, field(), _currentElementId, _currentElementWeight); + qt.add(field(), _currentElementId, _currentElementWeight, _words + pos); } public: static search::byte _foldLowCase[256]; diff --git a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp index c0a0249125f..98e88e45b3a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp @@ -34,7 +34,7 @@ bool StrChrFieldSearcher::matchDoc(const FieldRef & fieldRef) } } else { for (auto qt : _qtl) { - if (fieldRef.size() >= qt->termLen() || qt->isRegex()) { + if (fieldRef.size() >= qt->termLen() || qt->isRegex() || qt->isFuzzy()) { _words += matchTerm(fieldRef, *qt); } else { _words += countWords(fieldRef); @@ -49,8 +49,8 @@ size_t StrChrFieldSearcher::shortestTerm() const size_t mintsz(_qtl.front()->termLen()); for (auto it=_qtl.begin()+1, mt=_qtl.end(); it != mt; it++) { const QueryTerm & qt = **it; - if (qt.isRegex()) { - return 0; // Must avoid "too short query term" optimization when using regex + if (qt.isRegex() || qt.isFuzzy()) { + return 0; // Must avoid "too short query term" optimization when using regex or fuzzy } mintsz = std::min(mintsz, qt.termLen()); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp index c6deb6eacd1..5f626ccb962 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "utf8flexiblestringfieldsearcher.h" +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <cassert> @@ -36,7 +37,20 @@ UTF8FlexibleStringFieldSearcher::match_regexp(const FieldRef & f, search::stream if (regexp_term->regexp().partial_match({f.data(), f.size()})) { addHit(qt, 0); } - return countWords(f); + return 1; +} + +size_t +UTF8FlexibleStringFieldSearcher::match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt) +{ + auto* fuzzy_term = qt.as_fuzzy_term(); + assert(fuzzy_term != nullptr); + // TODO delegate to matchTermExact if max edits == 0? 
+ // - needs to avoid folding to have consistent normalization semantics + if (fuzzy_term->is_match({f.data(), f.size()})) { + addHit(qt, 0); + } + return 1; } size_t @@ -57,6 +71,9 @@ UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) } else if (qt.isRegex()) { LOG(debug, "Use regexp match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); return match_regexp(f, qt); + } else if (qt.isFuzzy()) { + LOG(debug, "Use fuzzy match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); + return match_fuzzy(f, qt); } else { if (substring()) { LOG(debug, "Use substring match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h index cd1715ad158..a5f6ad46246 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h @@ -25,6 +25,7 @@ private: size_t matchTerms(const FieldRef & f, size_t shortestTerm) override; size_t match_regexp(const FieldRef & f, search::streaming::QueryTerm & qt); + size_t match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt); public: std::unique_ptr<FieldSearcher> duplicate() const override; diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 9c8bb2f185a..63d2007cecf 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -133,7 +133,7 @@ FieldSearchSpec::reconfig(const QueryTerm & term) (term.isSuffix() && _arg1 != "suffix") || (term.isExactstring() && _arg1 != "exact") || (term.isPrefix() && _arg1 == "suffix") || - term.isRegex()) + (term.isRegex() || term.isFuzzy())) { _searcher = std::make_unique<UTF8FlexibleStringFieldSearcher>(id()); propagate_settings_to_searcher(); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 5d88b2d2829..df75a6f6d1f 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -720,6 +720,20 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void setIndex(int, int)", + "public long getDirectIndex()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,8 +908,11 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public abstract double get(long)", "public abstract float getFloat(long)", @@ -952,6 +969,7 @@ "public int sizeAsInt()", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public 
java.util.Iterator valueIterator()", "public java.util.Map cells()", @@ -1032,6 +1050,7 @@ "public com.yahoo.tensor.TensorType type()", "public long size()", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", @@ -1158,6 +1177,7 @@ "public int sizeAsInt()", "public abstract double get(com.yahoo.tensor.TensorAddress)", "public abstract boolean has(com.yahoo.tensor.TensorAddress)", + "public abstract java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public abstract java.util.Iterator cellIterator()", "public abstract java.util.Iterator valueIterator()", "public abstract java.util.Map cells()", @@ -1454,6 +1474,9 @@ ], "methods" : [ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", + "public boolean hasIndexedDimensions()", + "public boolean hasMappedDimensions()", + "public boolean hasOnlyIndexedBoundDimensions()", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", @@ -1464,6 +1487,7 @@ "public java.util.Set dimensionNames()", "public java.util.Optional dimension(java.lang.String)", "public java.util.Optional indexOfDimension(java.lang.String)", + "public int indexOfDimensionAsInt(java.lang.String)", "public java.util.Optional sizeOfDimension(java.lang.String)", "public boolean isAssignableTo(com.yahoo.tensor.TensorType)", "public boolean isConvertibleTo(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..37752361876 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + */ +public final class DirectIndexedAddress { + private final DimensionSizes sizes; + private final int [] indexes; + private long directIndex; + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + /** Retrieve the index that can be used for direct lookup in an indexed tensor. 
*/ + public long getDirectIndex() { return directIndex; } + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1319675f5d4..f26174d9576 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor { return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * @@ -116,6 +120,17 @@ } @Override + public Double getAsDouble(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0 || size() <= index) return null; + return get(index); + } catch (IllegalArgumentException e) { + return null; + } + } + + @Override public boolean has(TensorAddress address) { try { long index = toValueIndex(address, dimensionSizes, type); @@ -160,9 +175,10 @@ long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * label; } return valueIndex; } @@ -277,7 +293,7 @@ } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -291,7 +307,7 @@ * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -305,7 +321,7 @@ * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -611,11 +627,11 @@ private final class ValueIterator implements Iterator<Double> { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override
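The new directAddress()/get(DirectIndexedAddress) pair exists to avoid building a TensorAddress per cell when scanning along one dimension: lock the stable dimensions once, then step by stride. A hedged Java sketch of the loop described in the DirectIndexedAddress class comment (the tensor spec, class name, and dimension choice below are illustrative, not part of the patch):

import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;

class DirectAddressSketch {

    // Sums tensor(x, y) over y for a fixed x, reading cells by direct value index.
    static double sumRow(IndexedTensor tensor, int x, int ySize) {
        DirectIndexedAddress addr = tensor.directAddress();
        addr.setIndex(0, x);               // lock dimension 0 (x) for this iteration
        long base = addr.getDirectIndex(); // direct index of cell (x, 0)
        long stride = addr.getStride(1);   // distance between neighbors along dimension 1 (y)
        double sum = 0;
        for (int i = 0; i < ySize; i++) {
            sum += tensor.get(base + i * stride); // get(long) bypasses TensorAddress entirely
        }
        return sum;
    }

    public static void main(String[] args) {
        var t = (IndexedTensor) Tensor.from("tensor(x[2],y[3]):[[1,2,3],[4,5,6]]");
        System.out.println(sumRow(t, 1, 3)); // 4 + 5 + 6 = 15.0
    }
}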
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 5471ea65b97..3e0df5f2261 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -41,6 +41,9 @@ public boolean has(TensorAddress address) { return cells.containsKey(address); } @Override + public Double getAsDouble(TensorAddress address) { return cells.get(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 74b338fb503..95d1d70118a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -2,13 +2,15 @@ package com.yahoo.tensor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.impl.NumericTensorAddress; import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -29,7 +31,6 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int denseSubspaceSize; // XXX consider using "record" instead /** only exposed for internal use; subject to change without notice */ @@ -51,45 +52,15 @@ } } - /** The cells in the tensor */ - private final List<DenseSubspace> denseSubspaces; - /** only exposed for internal use; subject to change without notice */ - public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } + public List<DenseSubspace> getInternalDenseSubspaces() { return index.denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { + private MixedTensor(TensorType type, Index index) { this.type = type; - this.denseSubspaceSize = index.denseSubspaceSize(); - this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.denseSubspaceSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); - } - long count = 0; - for (var block : this.denseSubspaces) { - if (index.sparseMap.get(block.sparseAddress) != count) { - throw new IllegalStateException("map vs list mismatch: block #" - + count - + " address maps to #" - + index.sparseMap.get(block.sparseAddress)); - } - if (block.cells.length != denseSubspaceSize) { - throw new IllegalStateException("dense subspace size mismatch, expected " - + denseSubspaceSize - + " cells, but got: " - + block.cells.length); - } - ++count; - } - if (count != index.sparseMap.size()) { - throw new IllegalStateException("mismatch: list size is " - + count - + " but map size is " - + index.sparseMap.size()); - } } /** Returns the tensor type */ @@ -98,32 +69,34 @@ /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return denseSubspaces.size() * denseSubspaceSize; } + public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int blockNum = 
index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { + var block = index.blockOf(address); + int denseOffset = index.denseOffsetOf(address); + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } + return block.cells[denseOffset]; + } + + @Override + public Double getAsDouble(TensorAddress address) { + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - if (denseOffset < 0 || denseOffset >= block.cells.length) { - return 0.0; + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { + return null; } return block.cells[denseOffset]; } @Override public boolean has(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return false; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - return (denseOffset >= 0 && denseOffset < block.cells.length); + return (block != null && denseOffset >= 0 && denseOffset < block.cells.length); } /** @@ -136,20 +109,26 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); DenseSubspace currBlock = null; - int currOffset = denseSubspaceSize; + final long[] labels = new long[index.indexedDimensions.size()]; + int currOffset = index.denseSubspaceSize; + int prevOffset = -1; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Cell next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next(); currOffset = 0; } - TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, currOffset); + if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1 + index.denseOffsetToAddress(currOffset, labels); + } + TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, labels); + prevOffset = currOffset; double value = currBlock.cells[currOffset++]; return new Cell(fullAddr, value); } @@ -163,16 +142,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); double[] currBlock = null; - int currOffset = denseSubspaceSize; + int currOffset = index.denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Double next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next().cells; currOffset = 0; } @@ -198,24 +177,22 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. 
Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, denseSubspaces, index); + return new MixedTensor(other, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseSubspace> list = new ArrayList<>(); - for (var block : denseSubspaces) { + for (var block : index.denseSubspaces) { if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part - indexBuilder.addBlock(block.sparseAddress, list.size()); - list.add(block); + indexBuilder.addBlock(block); } } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } @Override - public int hashCode() { return Objects.hash(type, denseSubspaces); } + public int hashCode() { return Objects.hash(type, index.denseSubspaces); } @Override public String toString() { @@ -250,13 +227,14 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return denseSubspaceSize; + return index.denseSubspaceSize; } /** * Base class for building mixed tensors. */ public abstract static class Builder implements Tensor.Builder { + static final int INITIAL_HASH_CAPACITY = 1000; final TensorType type; @@ -266,10 +244,11 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. */ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + if (type.hasIndexedUnboundDimensions()) { + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } else { - return new BoundBuilder(type); + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -307,13 +286,14 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Map<TensorAddress, double[]> denseSubspaceMap; private final Index.Builder indexBuilder; private final Index index; private final TensorType denseSubtype; - private BoundBuilder(TensorType type) { + private BoundBuilder(TensorType type, int expectedSize) { super(type); + denseSubspaceMap = new LinkedHashMap<>(expectedSize, 0.5f); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); denseSubtype = new TensorType(type.valueType(), @@ -325,10 +305,7 @@ public class MixedTensor implements Tensor { } private double[] denseSubspace(TensorAddress sparseAddress) { - if (!denseSubspaceMap.containsKey(sparseAddress)) { - denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); - } - return denseSubspaceMap.get(sparseAddress); + return denseSubspaceMap.computeIfAbsent(sparseAddress, (key) -> new double[(int)denseSubspaceSize()]); } public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { @@ -363,19 +340,20 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseSubspace> list = new ArrayList<>(); - for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { + //TODO This can be solved more efficiently with a single map. 
+ Set<Map.Entry<TensorAddress, double[]>> entrySet = denseSubspaceMap.entrySet(); + for (Map.Entry<TensorAddress, double[]> entry : entrySet) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); var block = new DenseSubspace(sparsePart, denseSubspace); - indexBuilder.addBlock(sparsePart, list.size()); - list.add(block); + indexBuilder.addBlock(block); } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } public static BoundBuilder of(TensorType type) { - return new BoundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -392,9 +370,9 @@ public class MixedTensor implements Tensor { private final Map<TensorAddress, Double> cells; private final long[] dimensionBounds; - private UnboundBuilder(TensorType type) { + private UnboundBuilder(TensorType type, int expectedSize) { super(type); - cells = new HashMap<>(); + cells = new LinkedHashMap<>(expectedSize, 0.5f); dimensionBounds = new long[type.dimensions().size()]; } @@ -413,7 +391,7 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { TensorType boundType = createBoundType(); - BoundBuilder builder = new BoundBuilder(boundType); + BoundBuilder builder = new BoundBuilder(boundType, cells.size()); for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { builder.cell(cell.getKey(), cell.getValue()); } @@ -444,7 +422,8 @@ public class MixedTensor implements Tensor { } public static UnboundBuilder of(TensorType type) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -461,8 +440,10 @@ public class MixedTensor implements Tensor { private final TensorType denseType; private final List<TensorType.Dimension> mappedDimensions; private final List<TensorType.Dimension> indexedDimensions; + private final int [] indexedDimensionsSize; private ImmutableMap<TensorAddress, Integer> sparseMap; + private List<DenseSubspace> denseSubspaces; private final int denseSubspaceSize; static private int computeDSS(List<TensorType.Dimension> dimensions) { @@ -478,17 +459,31 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).toList(); this.indexedDimensions = type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList(); + this.indexedDimensionsSize = new int[indexedDimensions.size()]; + for (int i = 0; i < indexedDimensions.size(); i++) { + long dimensionSize = indexedDimensions.get(i).size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")); + indexedDimensionsSize[i] = (int)dimensionSize; + } + this.sparseType = createPartialType(type.valueType(), mappedDimensions); this.denseType = createPartialType(type.valueType(), indexedDimensions); this.denseSubspaceSize = computeDSS(this.indexedDimensions); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); + } } - int blockIndexOf(TensorAddress address) { + private DenseSubspace blockOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - return sparseMap.getOrDefault(sparsePart, -1); + Integer blockNum = sparseMap.get(sparsePart); + if (blockNum == null || blockNum >= denseSubspaces.size()) { + return null; + } + return 
denseSubspaces.get(blockNum); } - int denseOffsetOf(TensorAddress address) { + private int denseOffsetOf(TensorAddress address) { long innerSize = 1; long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { @@ -519,38 +514,32 @@ public class MixedTensor implements Tensor { return builder.build(); } - private TensorAddress denseOffsetToAddress(long denseOffset) { + private void denseOffsetToAddress(long denseOffset, long [] labels) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } long restSize = denseOffset; long innerSize = denseSubspaceSize; - long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { - TensorType.Dimension dimension = indexedDimensions.get(i); - long dimensionSize = dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); - - innerSize /= dimensionSize; + innerSize /= indexedDimensionsSize[i]; labels[i] = restSize / innerSize; restSize %= innerSize; } - return TensorAddress.of(labels); } - TensorAddress fullAddressOf(TensorAddress sparsePart, long denseOffset) { - TensorAddress densePart = denseOffsetToAddress(denseOffset); + private TensorAddress fullAddressOf(TensorAddress sparsePart, long [] densePart) { String[] labels = new String[type.dimensions().size()]; int mappedIndex = 0; int indexedIndex = 0; - for (TensorType.Dimension d : type.dimensions()) { + for (int i = 0; i < type.dimensions().size(); i++) { + TensorType.Dimension d = type.dimensions().get(i); if (d.isIndexed()) { - labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex); + labels[i] = NumericTensorAddress.asString(densePart[indexedIndex]); indexedIndex++; } else { - labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex); + labels[i] = sparsePart.label(mappedIndex); mappedIndex++; } } @@ -606,8 +595,7 @@ public class MixedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (type.valueType()) { @@ -620,32 +608,38 @@ public class MixedTensor implements Tensor { } // end bracket - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } return index; } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; + return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } - static class Builder { + private static class Builder { private final Index index; - private final ImmutableMap.Builder<TensorAddress, Integer> builder; + private final ImmutableMap.Builder<TensorAddress, Integer> builder = new ImmutableMap.Builder<>(); + private final ImmutableList.Builder<DenseSubspace> listBuilder = new ImmutableList.Builder<>(); + private int count = 0; Builder(TensorType type) { index = new Index(type); - builder = new ImmutableMap.Builder<>(); } - void addBlock(TensorAddress address, int sz) { - builder.put(address, sz); + void addBlock(DenseSubspace block) { + if (block.cells.length != index.denseSubspaceSize) { + throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize + + " cells, but got: " + block.cells.length); + } + builder.put(block.sparseAddress, count++); + listBuilder.add(block); } Index build() { 
index.sparseMap = builder.build(); + index.denseSubspaces = listBuilder.build(); return index; } @@ -655,27 +649,16 @@ } } - private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { - - private final TensorType type; - private final double[] values; - - public DenseSubspaceBuilder(TensorType type, double[] values) { - this.type = type; - this.values = values; - } - - @Override - public TensorType type() { return type; } + private record DenseSubspaceBuilder(TensorType type, double[] values) implements IndexedTensor.DirectIndexBuilder { @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; + values[(int) index] = value; } @Override public void cellByDirectIndex(long index, float value) { - values[(int)index] = value; + values[(int) index] = value; } }
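IndexedTensor, MappedTensor, and MixedTensor above all implement the single-probe lookup that the Tensor interface declares just below and that Join and Merge switch to further down. A hedged sketch of the calling pattern the new method enables (method and class names here are illustrative):

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import java.util.function.DoubleBinaryOperator;

class SingleProbeSketch {
    // Old shape: b.has(key) followed by b.get(key) probes the cell store twice.
    // New shape: one getAsDouble(key) call, with null standing in for "no such cell".
    static void joinCell(Tensor.Builder builder, TensorAddress key, double aValue,
                         Tensor b, DoubleBinaryOperator combinator) {
        Double bVal = b.getAsDouble(key);
        if (bVal != null)
            builder.cell(key, combinator.applyAsDouble(aValue, bVal));
    }
}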
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cc8e1602adb..d034ac551f8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -90,6 +90,8 @@ public interface Tensor { /** Returns true if this cell exists */ boolean has(TensorAddress address); + /** null = no value present. More efficient than if (t.has(key)) t.get(key) */ + Double getAsDouble(TensorAddress address); /** * Returns the cell of this in some undefined order. * @@ -113,7 +115,7 @@ * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -553,8 +555,8 @@ /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -565,8 +567,8 @@ /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index f841b7757fb..1b88a5d1b2f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -3,10 +3,12 @@ package com.yahoo.tensor; import com.yahoo.tensor.impl.NumericTensorAddress; import com.yahoo.tensor.impl.StringTensorAddress; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Objects; -import java.util.Optional; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -16,6 +18,8 @@ import java.util.Optional; */ public abstract class TensorAddress implements Comparable<TensorAddress> { + private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32(); + public static TensorAddress of(String[] labels) { return StringTensorAddress.of(labels); } @@ -28,6 +32,8 @@ return NumericTensorAddress.of(labels); } + private int cached_hash = 0; + /** Returns the number of labels in this */ public abstract int size(); @@ -62,12 +68,17 @@ @Override public int hashCode() { - int result = 1; + if (cached_hash != 0) return cached_hash; + + int hash = 0; for (int i = 0; i < size(); i++) { - if (label(i) != null) - result = 31 * result + label(i).hashCode(); + String label = label(i); + if (label != null) { + byte [] buf = label.getBytes(StandardCharsets.UTF_8); + hash = hasher.hash(buf, 0, buf.length, hash); + } } - return result; + return cached_hash = hash; } @Override @@ -138,10 +149,10 @@ public Builder add(String dimension, String label) { Objects.requireNonNull(dimension, "dimension cannot be null"); Objects.requireNonNull(label, "label cannot be null"); - Optional<Integer> labelIndex = type.indexOfDimension(dimension); - if ( labelIndex.isEmpty()) + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); - labels[labelIndex.get()] = label; + labels[labelIndex] = label; return this; }
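The hashCode() rewrite above trades the 31-multiply loop for xxHash over each label's UTF-8 bytes, chained through the seed parameter and cached in the instance; 0 doubles as the "not computed" sentinel, so an address whose labels happen to hash to 0 is simply recomputed on each call. A self-contained sketch of that pattern under the same lz4-java API (class and field names are illustrative):

import net.jpountz.xxhash.XXHash32;
import net.jpountz.xxhash.XXHashFactory;
import java.nio.charset.StandardCharsets;

class CachedLabelHash {
    private static final XXHash32 HASHER = XXHashFactory.fastestJavaInstance().hash32();
    private final String[] labels;
    private int cachedHash = 0; // 0 also means "not computed yet"; recomputing is idempotent

    CachedLabelHash(String... labels) { this.labels = labels; }

    @Override
    public int hashCode() {
        if (cachedHash != 0) return cachedHash;
        int hash = 0;
        for (String label : labels) {
            byte[] buf = label.getBytes(StandardCharsets.UTF_8);
            hash = HASHER.hash(buf, 0, buf.length, hash); // previous value seeds the next label
        }
        return cachedHash = hash; // benign race: every thread computes the same value
    }

    public static void main(String[] args) {
        var a = new CachedLabelHash("cat", "2");
        System.out.println(a.hashCode() == a.hashCode()); // true; the second call hits the cache
    }
}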
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..dcfee88d599 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List<Dimension> dimensions; + private final Set<String> dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection<Dimension> dimensions) { @@ -103,12 +108,25 @@ List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ } } + public boolean hasIndexedDimensions() { return indexedSubtype != empty; } + public boolean hasMappedDimensions() { return mappedSubtype != empty; } + public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... 
types) { List<Value> valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ @@ -176,6 +199,13 @@ return Optional.of(i); return Optional.empty(); } + /** Returns the 0-based index of this dimension, or -1 if it is not present */ + public int indexOfDimensionAsInt(String dimension) { + for (int i = 0; i < dimensions.size(); i++) + if (dimensions.get(i).name().equals(dimension)) + return i; + return -1; + } /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ public Optional<Long> sizeOfDimension(String dimension) { 
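The new predicates and the int-returning dimension lookup let the call sites below (Concat, DynamicTensor, Tensor.Builder.of, Join.mapIndexes) drop per-call stream scans and Optional boxing. A hedged sketch of the replacement idioms (the type spec and class name are illustrative):

import com.yahoo.tensor.TensorType;

class TypePredicateSketch {
    public static void main(String[] args) {
        TensorType type = TensorType.fromSpec("tensor(key{},x[3])");

        // Old: type.dimensions().stream().anyMatch(d -> !d.isIndexed())
        System.out.println(type.hasMappedDimensions());            // true
        // Old: allMatch(d -> d instanceof TensorType.IndexedBoundDimension)
        System.out.println(type.hasOnlyIndexedBoundDimensions());  // false: 'key' is mapped

        // Old: indexOfDimension(...).orElse(-1) boxed an Optional<Integer> per call
        System.out.println(type.indexOfDimensionAsInt("x"));       // 1 (dimensions are sorted: key, x)
        System.out.println(type.indexOfDimensionAsInt("absent"));  // -1
    }
}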
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 8d8fe2b356f..866b710b72e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -134,7 +134,7 @@ return tensor; } else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + if (tensor.type().hasMappedDimensions()) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 3b6e03186a3..b595b1a40cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -40,7 +40,7 @@ @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 0) + if (!arguments.isEmpty()) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @@ -79,7 +79,7 @@ public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells.values()) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } @@ -133,7 +133,7 @@ IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); - if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + if ( ! type.hasOnlyIndexedBoundDimensions()) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + "only indexed, bound dimensions, but this has " + type); this.cells = List.copyOf(cells); @@ -142,7 +142,7 @@ public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 1ded16636d3..e0ac549651c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -114,7 +114,7 @@ } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { - long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); @@ -129,8 +129,9 @@ for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (b.has(key)) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + Double bVal = b.getAsDouble(key); + if (bVal != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); @@ -170,7 +171,7 @@ Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { - long joinedLength = Math.min(subspaceSize, superspaceSize); + int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -206,11 +207,12 @@ for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - if (subspace.has(subaddress)) { - double subspaceValue = subspace.get(subaddress); + Double subspaceValue = subspace.getAsDouble(subaddress); + if (subspaceValue != null) { builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder + ? 
combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); @@ -252,6 +254,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + int sharedDimensionSize = sharedDimensions.size(); // Cache this: size() of a Sets.intersection view is computed by iterating it Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); @@ -263,7 +266,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize); // for each matching combination of dimensions only in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -275,8 +278,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } - private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { - PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, + Set<String> retainDimensions, int sharedDimensionSize) { + PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); @@ -331,12 +335,11 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions - Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(); + Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); - aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>()); - aCellsByCommonAddress.get(partialCommonAddress).add(aCell); + aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } // Iterate once through the larger tensor and use the hash map to find joinable cells @@ -359,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } /** - * Returns the an array having one entry in order for each dimension of fromType + * Returns an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name.
* That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) @@ -368,7 +371,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name()); return toIndexes; } @@ -390,8 +393,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; - to[toIndex] = from.label(i); + String label = from.label(i); + if (to[toIndex] != null && ! to[toIndex].equals(label)) return false; + to[toIndex] = label; } return true; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 59394785382..ddad91dc060 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (! b.has(key)) { + Double bVal = b.getAsDouble(key); + if (bVal == null) { builder.cell(key, aCell.getValue()); } else if (combinator != null) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index fe20c41174a..77e82b818a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; @@ -136,9 +137,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); - ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - if (aggr == null) - aggr = aggregatingCells.get(reducedAddress); + ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) -> ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); @@ -172,14 +171,15 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
valueAggregator.aggregate(i.next()); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.dimensionSizes().size(0); i++) + int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0)); + for (int i = 0; i < dimensionSize; i++) valueAggregator.aggregate(argument.get(i)); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } static abstract class ValueAggregator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index aece782d296..2d5a0518747 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if ( ! (a instanceof IndexedTensor)) return false; - if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (a.type().hasOnlyIndexedBoundDimensions())) return false; if ( ! (b instanceof IndexedTensor)) return false; - if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (b.type().hasOnlyIndexedBoundDimensions())) return false; TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java new file mode 100644 index 00000000000..e2cb64fdd1f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java @@ -0,0 +1,16 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +/** + * Utility to make common conversions safe + * + * @author baldersheim + */ +public class Convert { + public static int safe2Int(long value) { + if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException("value = " + value + ", which does not fit in an int"); + } + return (int) value; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 444ce02b14a..771b74633d9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -21,10 +21,8 @@ import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Slice; import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Set; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -60,8 +58,7 @@ public class JsonFormat { // Short form for a single mapped dimension Cursor parent = root == null ?
slime.setObject() : root.setObject("cells"); encodeSingleDimensionCells((MappedTensor) tensor, parent); - } else if (tensor instanceof MixedTensor && - tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) { + } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) { // Short form for a mixed tensor boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1; Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() ) @@ -204,7 +201,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) + else if (root.field("values").valid() && ! builder.type().hasMappedDimensions()) decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); @@ -298,14 +295,14 @@ public class JsonFormat { /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { - boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + boolean hasIndexed = builder.type().hasIndexedDimensions(); + boolean hasMapped = builder.type().hasMappedDimensions(); if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! hasMapped) decodeValuesAtTop(root, builder); - else if (hasMapped && hasIndexed) + else if (hasIndexed) decodeBlocks(root, builder); else decodeCells(root, builder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index d4b18c73f11..0a5c713f3e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -55,8 +55,8 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMappedDimensions = tensor.type().hasMappedDimensions(); + boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions(); boolean isMixed = hasMappedDimensions && hasIndexedDimensions; // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..afc95d295f0 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -96,6 +97,38 @@ public class IndexedTensorTestCase { } @Test + public void 
testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getDirectIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getDirectIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 5c4d5f1ffcf..74237a218fb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -9,8 +9,10 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import java.util.*; -import java.util.stream.Collectors; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + /** * Microbenchmark of tensor operations. 
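The stride assertions in testDirectIndexedAddress above follow the usual row-major layout: the innermost dimension has stride 1, each outer stride is the product of all inner dimension sizes, and the direct index is the dot product of indexes and strides. A minimal sketch of that arithmetic, under the assumption that this is what DirectIndexedAddress computes internally (StrideSketch, computeStrides, directIndex and the concrete sizes are illustrative, not the Vespa API):

```java
// Illustrative sketch only; DirectIndexedAddress is the real Vespa class.
class StrideSketch {

    // Row-major strides: the innermost (last) dimension varies fastest.
    static long[] computeStrides(long[] sizes) {
        long[] strides = new long[sizes.length];
        long stride = 1;
        for (int i = sizes.length - 1; i >= 0; i--) {
            strides[i] = stride;
            stride *= sizes[i];
        }
        return strides;
    }

    // The flat ("direct") index is the dot product of indexes and strides.
    static long directIndex(long[] strides, long[] indexes) {
        long index = 0;
        for (int i = 0; i < strides.length; i++)
            index += indexes[i] * strides[i];
        return index;
    }

    public static void main(String[] args) {
        long[] sizes = {3, 5, 7, 9, 11};   // stand-ins for the v, w, x, y, z sizes
        long[] strides = computeStrides(sizes);
        System.out.println(strides[0] == 5 * 7 * 9 * 11); // true, as getStride(0) asserts
        System.out.println(directIndex(strides, new long[] {1, 1, 2, 2, 2}));
    }
}
```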
@@ -91,7 +93,7 @@ public class TensorFunctionBenchmark { .value(random.nextDouble()); } } - return Collections.singletonList(builder.build()); + return List.of(builder.build()); } private static TensorType vectorType(TensorType.Builder builder, String name, TensorType.Dimension.Type type, int size) { @@ -107,45 +109,51 @@ public class TensorFunctionBenchmark { public static void main(String[] args) { double time = 0; - // ---------------- Mapped with extra space (sidesteps current special-case optimizations): - // 7.8 ms - time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // 7.7 ms - time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + // ---------------- Indexed unbound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + + // ---------------- Indexed bound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); + + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Mapped: - // 2.1 ms time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); - // 7.0 ms + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): - // 14.5 ms time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); - // 8.9 ms time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed unbound: - // 0.14 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, 
TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + // ---------------- Mapped with extra space (sidesteps current special-case optimizations): + time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed bound: - // 0.32 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + + /* 2.4 GHz Intel Core i9, MacBook Pro 2019 + * Indexed unbound vectors, time per join: 0,067 ms + * Indexed unbound matrix, time per join: 0,107 ms + * Indexed bound vectors, time per join: 0,068 ms + * Indexed bound matrix, time per join: 0,105 ms + * Mapped vectors, time per join: 1,342 ms + * Mapped matrix, time per join: 3,448 ms + * Indexed vectors, x space time per join: 6,398 ms + * Indexed matrix, x space time per join: 3,220 ms + * Mapped vectors, x space time per join: 14,984 ms + * Mapped matrix, x space time per join: 19,873 ms + */ } } diff --git a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp index 2f09e331c5d..c07743554ea 100644 --- a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp +++ b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp @@ -40,12 +40,14 @@ struct MyInvCmp { struct Timer { double minTime; - vespalib::Timer timer; - Timer() : minTime(1.0e10), timer() {} - void start() { timer = vespalib::Timer(); } + vespalib::steady_time start_time; + Timer() : minTime(1.0e10), start_time() {} + void start() { + start_time = vespalib::steady_clock::now(); + } void stop() { - double ms = vespalib::count_ms(timer.elapsed()); - minTime = std::min(minTime, ms); + std::chrono::duration<double,std::milli> elapsed = vespalib::steady_clock::now() - start_time; + minTime = std::min(minTime, elapsed.count()); } }; diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp index e29b91d6522..bd4ff5d93a9 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp @@ -11,15 +11,15 @@ namespace { class Initialize { public: - Initialize() { Fast_UnicodeUtil::InitTables(); } + Initialize() noexcept { Fast_UnicodeUtil::InitTables(); } }; -Initialize _G_Initializer; +Initialize _g_initializer; } void -Fast_UnicodeUtil::InitTables() +Fast_UnicodeUtil::InitTables() noexcept { /** * Hack for Katakana accent marks (torgeir) */ @@ -29,8 +29,7 @@ Fast_UnicodeUtil::InitTables() } char * -Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, - int maxdst, int maxsrc) +Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int maxdst,
int maxsrc) noexcept { char * p = dst; char * edst = dst + maxdst; @@ -83,7 +82,7 @@ Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int -Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) +Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) noexcept { ucs4_t i1; ucs4_t i2; @@ -101,7 +100,7 @@ Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) } size_t -Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) +Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) noexcept { const ucs4_t *p = str; while (*p++ != 0) { @@ -111,7 +110,7 @@ Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) } ucs4_t * -Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) +Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) noexcept { ucs4_t i; ucs4_t *p; @@ -127,7 +126,7 @@ Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) } ucs4_t -Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) +Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) noexcept { ucs4_t retval; @@ -222,7 +221,7 @@ Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) } ucs4_t -Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) +Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) noexcept { return (*src >= 0x80) ? GetUTF8CharNonAscii(src) @@ -246,7 +245,7 @@ Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) #define UTF8_STARTCHAR(c) (!((c) & 0x80) || ((c) & 0x40)) int Fast_UnicodeUtil::UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset) + unsigned const char*& pos, off_t offset) noexcept { int increment = offset > 0 ? 1 : -1; unsigned const char* p = pos; diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.h b/vespalib/src/vespa/fastlib/text/unicodeutil.h index 87c09826948..740cc9381b7 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.h +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.h @@ -16,7 +16,7 @@ using ucs4_t = uint32_t; * Used to examine properties of unicode characters, and * provide fast conversion methods between often used encodings. */ -class Fast_UnicodeUtil { +class Fast_UnicodeUtil final { private: /** * Is true when the tables have been initialized. Is set by @@ -46,9 +46,8 @@ private: }; public: - virtual ~Fast_UnicodeUtil() { } /** Initialize the ISO 8859-1 static tables. */ - static void InitTables(); + static void InitTables() noexcept; /** Indicates an invalid UTF-8 character sequence. */ enum { _BadUTF8Char = 0xfffffffeu }; @@ -64,7 +63,7 @@ public: * one or more of the properties alphabetic, ideographic, * combining char, decimal digit char, private use, extender. */ - static bool IsWordChar(ucs4_t testchar) { + static bool IsWordChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _wordcharProp) != 0); @@ -80,8 +79,8 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8Char(const unsigned char *& src); - static ucs4_t GetUTF8Char(const char *& src) { + static ucs4_t GetUTF8Char(const unsigned char *& src) noexcept; + static ucs4_t GetUTF8Char(const char *& src) noexcept { const unsigned char *temp = reinterpret_cast<const unsigned char *>(src); ucs4_t res = GetUTF8Char(temp); src = reinterpret_cast<const char *>(temp); @@ -94,7 +93,7 @@ public: * @param i The UCS4 character. * @return Pointer to the next position in dst after the written byte(s).
*/ - static char *utf8cput(char *dst, ucs4_t i) { + static char *utf8cput(char *dst, ucs4_t i) noexcept { if (i < 128) *dst++ = i; else if (i < 0x800) { @@ -132,14 +131,14 @@ public: * @param src The UTF-8 source buffer. * @return A pointer to the destination string. */ - static ucs4_t *ucs4copy(ucs4_t *dst, const char *src); + static ucs4_t *ucs4copy(ucs4_t *dst, const char *src) noexcept; /** * Get the length of the UTF-8 representation of a UCS4 character. * @param i The UCS4 character. * @return The number of bytes required for the UTF-8 representation. */ - static size_t utf8clen(ucs4_t i) { + static size_t utf8clen(ucs4_t i) noexcept { if (i < 128) return 1; else if (i < 0x800) @@ -159,7 +158,7 @@ public: * @param testchar The character to lowercase. * @return The lowercase of the input if defined, else the input character. */ - static ucs4_t ToLower(ucs4_t testchar) + static ucs4_t ToLower(ucs4_t testchar) noexcept { ucs4_t ret; if (testchar < 65536) { @@ -182,14 +181,14 @@ public: * @return Number of bytes moved, or -1 if out of range */ static int UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset); + unsigned const char*& pos, off_t offset) noexcept; /** * Find the number of characters in a UCS4 string. * @param str The UCS4 string. * @return The number of characters. */ - static size_t ucs4strlen(const ucs4_t *str); + static size_t ucs4strlen(const ucs4_t *str) noexcept; /** * Convert UCS4 to UTF-8, bounded by max lengths. @@ -199,7 +198,7 @@ public: * @param maxsrc The maximum number of characters to convert from src. * @return A pointer to the destination. */ - static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc); + static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc) noexcept; /** @@ -210,7 +209,7 @@ public: * if s1 is, respectively, less than, matching, or greater than s2. * NB Only used in local test */ - static int utf8cmp(const char *s1, const ucs4_t *s2); + static int utf8cmp(const char *s1, const ucs4_t *s2) noexcept; /** * Test for terminal punctuation. @@ -218,7 +217,7 @@ * @return true if testchar is a terminal punctuation character, * i.e. if it has the terminal punctuation char property. */ - static bool IsTerminalPunctuationChar(ucs4_t testchar) { + static bool IsTerminalPunctuationChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _terminalPunctuationCharProp) != 0); @@ -233,10 +232,10 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src); + static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src) noexcept; // this is really an alias of the above function - static ucs4_t GetUTF8CharNonAscii(const char *&src) { + static ucs4_t GetUTF8CharNonAscii(const char *&src) noexcept { unsigned const char *temp = reinterpret_cast<unsigned const char *>(src); ucs4_t res = GetUTF8CharNonAscii(temp); src = reinterpret_cast<const char *>(temp);
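A pattern that recurs in the Join and Merge hunks above: the old code probed the mapped tensor twice per cell, once with has(key) and once with get(key), while the new code calls getAsDouble(key) once and treats a null result as absence. getAsDouble is the new Tensor API this diff relies on; the sketch below shows only the shape of the rewrite, with a plain HashMap standing in for a mapped tensor (all names illustrative):

```java
import java.util.HashMap;
import java.util.Map;

class SingleProbe {
    public static void main(String[] args) {
        Map<String, Double> b = new HashMap<>();
        b.put("cell", 2.0);

        // Before: containsKey + get, two hash probes per key.
        if (b.containsKey("cell"))
            System.out.println(3.0 * b.get("cell"));

        // After: one probe; a null result signals absence.
        Double bVal = b.get("cell");
        if (bVal != null)
            System.out.println(3.0 * bVal);
    }
}
```

In a join that iterates every cell of the smaller tensor, halving the probes on the hot path is where the saving comes from.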
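Similarly, the Reduce and Join hunks replace putIfAbsent followed by get with computeIfAbsent. Besides saving a probe, this avoids an eager allocation: putIfAbsent evaluates its value argument unconditionally, so the old Reduce loop constructed a ValueAggregator for every cell even when the key already existed, whereas computeIfAbsent only runs the factory on a miss and returns the mapping directly. A small grouping sketch of the same pattern (names illustrative):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class GroupByKey {
    public static void main(String[] args) {
        Map<String, List<Integer>> cellsByAddress = new HashMap<>();
        for (int cell : new int[] {1, 2, 3, 4}) {
            String address = (cell % 2 == 0) ? "even" : "odd";
            // The ArrayList is only created when the address is new.
            cellsByAddress.computeIfAbsent(address, key -> new ArrayList<>()).add(cell);
        }
        // e.g. {even=[2, 4], odd=[1, 3]} (HashMap iteration order is unspecified)
        System.out.println(cellsByAddress);
    }
}
```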
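Finally, mapIndexes in Join now calls the new TensorType.indexOfDimensionAsInt instead of indexOfDimension(...).orElse(-1). The observable result is identical, but the Optional<Integer> round trip allocates and boxes on every dimension of every type mapping, which the -1 sentinel avoids. A side-by-side sketch of the two shapes (DimensionIndexSketch and its field are illustrative stand-ins for the TensorType internals):

```java
import java.util.List;
import java.util.Optional;

class DimensionIndexSketch {
    static final List<String> dimensions = List.of("x", "y", "z");

    // Wraps the index in Optional<Integer>: allocates and boxes per call.
    static Optional<Integer> indexOf(String name) {
        for (int i = 0; i < dimensions.size(); i++)
            if (dimensions.get(i).equals(name)) return Optional.of(i);
        return Optional.empty();
    }

    // Allocation-free; -1 is the "not present" sentinel.
    static int indexOfAsInt(String name) {
        for (int i = 0; i < dimensions.size(); i++)
            if (dimensions.get(i).equals(name)) return i;
        return -1;
    }

    public static void main(String[] args) {
        System.out.println(indexOf("y").orElse(-1)); // 1
        System.out.println(indexOfAsInt("w"));       // -1
    }
}
```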